arkworks_native_gadgets/poseidon/
mod.rs

1// This file is part of Webb and was adapted from Arkworks.
2//
3// Copyright (C) 2021 Webb Technologies Inc.
4// SPDX-License-Identifier: Apache-2.0
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! A native implementation of the Poseidon hash function.
19//!
20//! The Poseidon hash function takes in a vector of elements of a prime field
21//! `F`, and outputs an element of `F`. This means it has the `FieldHasher`
22//! trait.
23//!
//! The `width` parameter is the maximum length of the input vector plus one.
//! This is because before hashing, we prepend one entry of zero to the input
//! vector (and pad the end with further zeros when the input is shorter than
//! `width - 1`).
27//!
28//! After this initial padding, Poseidon hashes the input vector through a
29//! number of cryptographic rounds, which can either be full rounds or partial
30//! rounds. (After the input vector begins to be processed, we call it the
31//! *state* vector).
32//!
33//! Each round is of the form ARC --> SB --> M, where
34//! - ARC stands for "add round constants."
35//! - SB stands for "S-box", (or "sub words") which means
36//! 	- raising **all** entries of the state vector to a fixed power alpha,
37//! 	in a full round.
38//! 	- raising **only the first** entry of the state vector to a fixed power
39//! 	alpha, in a partial round.
40//! - M stands for "mix layer," which means multiplying the state vector by a
41//!   fixed [MDS matrix](https://en.wikipedia.org/wiki/MDS_matrix).
42//!
43//! The output is the first entry of the state vector after the final round.
44//!
45//! The round constants and MDS matrix are precomputed and passed to Poseidon as
46//! parameters `round_keys` and `mds_matrix`, respectively.  There is a separate
47//! module `sbox` for setting the exponent alpha, which is passed to Poseidon as
48//! `sbox.0`.  Common values of alpha, which are supported in `sbox`, are
49//! 3, 5, 17, and -1: the default value is 5.
50//!
51//! Note that this is the *original* Poseidon hash function described in [the
52//! paper of Grassi, Khovratovich,
53//! Rechberger, Roy, and Schofnegger](https://eprint.iacr.org/2019/458.pdf),
54//! and NOT the optimized version described in
55//! [this page by Feng](https://hackmd.io/8MdoHwoKTPmQfZyIKEYWXQ).
56
// Importing dependencies
58use ark_crypto_primitives::Error;
59use ark_ff::{BigInteger, PrimeField};
60use ark_std::{error::Error as ArkError, io::Read, rand::Rng, string::ToString, vec::Vec};
61use sbox::PoseidonSbox;
62
63use super::{from_field_elements, to_field_elements};
64
65pub mod sbox;
66
67#[derive(Debug)]
68
69/// Error enum for the Poseidon hash function.  
70///
71/// See Variants for more information about when this error is thrown.
72pub enum PoseidonError {
73	/// Thrown if the S-box exponent alpha is not 3, 5, 17, or -1.
74	InvalidSboxSize(i8),
75
76	/// Thrown if the exponent alpha is -1 and the S-box tries to
77	/// take the inverse of zero.
78	ApplySboxFailed,
79
80	/// Thrown if the user attempts to input a vector whose length is
81	/// greater than the `width` parameter minus one.
82	InvalidInputs,
83}
84
85/// Error messages for PoseidonError.
86impl core::fmt::Display for PoseidonError {
87	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
88		use PoseidonError::*;
89		let msg = match self {
90			InvalidSboxSize(s) => format!("sbox is not supported: {}", s),
91			ApplySboxFailed => "failed to apply sbox".to_string(),
92			InvalidInputs => "invalid inputs".to_string(),
93		};
94		write!(f, "{}", msg)
95	}
96}
97
98impl ArkError for PoseidonError {}
99
100/// Parameters for the Poseidon hash function.
101#[derive(Default, Clone, Debug)]
102pub struct PoseidonParameters<F: PrimeField> {
103	/// Round constants
104	pub round_keys: Vec<F>,
105
106	/// MDS matrix to apply in the mix layer.
107	pub mds_matrix: Vec<Vec<F>>,
108
109	/// Number of full rounds
110	pub full_rounds: u8,
111
112	/// Number of partial rounds
113	pub partial_rounds: u8,
114
115	/// Length of the input, in field elements, plus one zero element.
116	pub width: u8,
117
118	/// S-box to apply in the sub words layer.
119	pub sbox: PoseidonSbox,
120}
121
122impl<F: PrimeField> PoseidonParameters<F> {
123	pub fn new(
124		round_keys: Vec<F>,
125		mds_matrix: Vec<Vec<F>>,
126		full_rounds: u8,
127		partial_rounds: u8,
128		width: u8,
129		sbox: PoseidonSbox,
130	) -> Self {
131		Self {
132			round_keys,
133			mds_matrix,
134			width,
135			full_rounds,
136			partial_rounds,
137			sbox,
138		}
139	}
140
141	pub fn generate<R: Rng>(_rng: &mut R) -> Self {
142		unimplemented!();
143	}
144
145	/// The MDS matrices used for the Poseidon hash functions of widths 2-17
146	/// have been pre-computed, audited for security, and published.
147	/// If we wanted to generated our own MDS matrix we could write and use
148	/// this function, but for the moment we only use the published matrices,
149	/// so it remains unimplemented.
150	pub fn create_mds<R: Rng>(_rng: &mut R) -> Vec<Vec<F>> {
151		unimplemented!();
152	}
153
154	/// The round constants used for the Poseidon hash functions of widths 2-17
155	/// have been pre-computed, audited for security, and published.
156	/// If we wanted to generated our own round constants we could write and use
157	/// this function, but for the moment we only use the published round
158	/// constants, so it remains unimplemented.
159	pub fn create_round_keys<R: Rng>(_rng: &mut R) -> Vec<F> {
160		unimplemented!();
161	}
162
163	/// Encodes the PoseidonParameters struct as a bytestring (vector of u8
164	/// integers), in the following way: [width, number of full rounds, number
165	/// of partial rounds, S-box exponent alpha, round constant length, round
166	/// constants, MDS matrix length, MDS matrix]. Bytes are stored the
167	/// big-endian way.
168	pub fn to_bytes(&self) -> Vec<u8> {
169		let max_elt_size = F::BigInt::NUM_LIMBS * 8;
170		let mut buf: Vec<u8> = vec![];
171
172		buf.extend(&self.width.to_be_bytes());
173		buf.extend(&self.full_rounds.to_be_bytes());
174		buf.extend(&self.partial_rounds.to_be_bytes());
175		buf.extend(&self.sbox.0.to_be_bytes());
176
177		// Appends the length of the round constants to the encoding,
178		// allowing the decoder to parse the round constants.
179		let round_key_len = self.round_keys.len() * max_elt_size;
180		buf.extend_from_slice(&(round_key_len as u32).to_be_bytes());
181
182		// Appends the round constants to the encoding.
183		buf.extend_from_slice(&from_field_elements(&self.round_keys).unwrap());
184
185		// Suppose that M, the MDS matrix in the PoseidonParameters instance,
186		// is a t x t matrix.  Then the next block does the following:
187		// 1. Finds t by returning the length of the first entry of M,
188		// since M is a vector of vectors.
189		// 2. Appends t to the encoding.
190		// 3. Flattens M and appends it to the encoding.
191		let mut stored = false;
192		for i in 0..self.mds_matrix.len() {
193			if !stored {
194				// the number of bytes to read for each inner mds matrix vec
195				let inner_vec_len = self.mds_matrix[i].len() * max_elt_size;
196				buf.extend_from_slice(&(inner_vec_len as u32).to_be_bytes());
197				stored = true;
198			}
199
200			buf.extend_from_slice(&from_field_elements(&self.mds_matrix[i]).unwrap());
201		}
202		buf
203	}
204
205	/// Decodes a (valid) bytestring into a PoseidonParameters struct.
206	/// Throws an error if the bytestring is not valid, i.e., is not the result
207	/// of encoding an instance of PoseidonParameters with `to_bytes`.
208	pub fn from_bytes(mut bytes: &[u8]) -> Result<Self, Error> {
209		let mut width_u8 = [0u8; 1];
210		bytes.read_exact(&mut width_u8)?;
211		let width = u8::from_be_bytes(width_u8);
212
213		let mut full_rounds_len = [0u8; 1];
214		bytes.read_exact(&mut full_rounds_len)?;
215		let full_rounds = u8::from_be_bytes(full_rounds_len);
216
217		let mut partial_rounds_u8 = [0u8; 1];
218		bytes.read_exact(&mut partial_rounds_u8)?;
219		let partial_rounds = u8::from_be_bytes(partial_rounds_u8);
220
221		let mut exponentiation_u8 = [0u8; 1];
222		bytes.read_exact(&mut exponentiation_u8)?;
223		let exp = i8::from_be_bytes(exponentiation_u8);
224
225		let mut round_key_len = [0u8; 4];
226		bytes.read_exact(&mut round_key_len)?;
227
228		let round_key_len_usize: usize = u32::from_be_bytes(round_key_len) as usize;
229		let mut round_keys_buf = vec![0u8; round_key_len_usize];
230		bytes.read_exact(&mut round_keys_buf)?;
231
232		let round_keys = to_field_elements::<F>(&round_keys_buf)?;
233		let mut mds_matrix_inner_vec_len = [0u8; 4];
234		bytes.read_exact(&mut mds_matrix_inner_vec_len)?;
235
236		let inner_vec_len_usize = u32::from_be_bytes(mds_matrix_inner_vec_len) as usize;
237		let mut mds_matrix: Vec<Vec<F>> = vec![];
238		while !bytes.is_empty() {
239			let mut inner_vec_buf = vec![0u8; inner_vec_len_usize];
240			bytes.read_exact(&mut inner_vec_buf)?;
241
242			let inner_vec = to_field_elements::<F>(&inner_vec_buf)?;
243			mds_matrix.push(inner_vec);
244		}
245
246		Ok(Self {
247			round_keys,
248			mds_matrix,
249			width,
250			full_rounds,
251			partial_rounds,
252			sbox: PoseidonSbox(exp),
253		})
254	}
255}
256
257#[derive(Default, Clone, Debug)]
258
259/// The Poseidon hash function struct.  As a struct it contains just
260/// one field `params`, which holds an instance of the `PoseidonParameters`
261/// struct.  The real magic happens in the implementation of the `FieldHasher`
262/// trait, which is where the Poseidon hashing algorithm can be found.
263pub struct Poseidon<F: PrimeField> {
264	pub params: PoseidonParameters<F>,
265}
266
267impl<F: PrimeField> Poseidon<F> {
268	pub fn new(params: PoseidonParameters<F>) -> Self {
269		Poseidon { params }
270	}
271}
272
273/// A field hasher over a prime field `F` is any cryptographic hash function
274/// that takes in a vector of elements of `F` and outputs a single element
275/// of `F`.
276pub trait FieldHasher<F: PrimeField> {
277	fn hash(&self, inputs: &[F]) -> Result<F, PoseidonError>;
278
279	/// With this method we separate the special case when the length of the
280	/// input vector is 2, since hashing together two field elements is
281	/// particularly useful in Merkle trees.
282	fn hash_two(&self, left: &F, right: &F) -> Result<F, PoseidonError>;
283}
284
285/// The Poseidon hash algorithm.
286impl<F: PrimeField> FieldHasher<F> for Poseidon<F> {
287	fn hash(&self, inputs: &[F]) -> Result<F, PoseidonError> {
288		// Casting params to usize
289		let width = self.params.width as usize;
290		let partial_rounds = self.params.partial_rounds as usize;
291		let full_rounds = self.params.full_rounds as usize;
292
293		// Populate a state vector with 0 and then inputs, pad with zeros if necessary
294		if inputs.len() > width - 1 {
295			return Err(PoseidonError::InvalidInputs);
296		}
297		let mut state = vec![F::zero()];
298		for f in inputs {
299			state.push(*f);
300		}
301		while state.len() < width {
302			state.push(F::zero());
303		}
304
305		let nr = full_rounds + partial_rounds;
306		for r in 0..nr {
307			// Adding round constants
308			state.iter_mut().enumerate().for_each(|(i, a)| {
309				let c = self.params.round_keys[(r * width + i)];
310				a.add_assign(c);
311			});
312
313			let half_rounds = full_rounds / 2;
314
315			if r < half_rounds || r >= half_rounds + partial_rounds {
316				// Applying an exponentiation S-box to the *first* entry of the
317				// state vector, during partial rounds
318				state
319					.iter_mut()
320					.try_for_each(|a| self.params.sbox.apply_sbox(*a).map(|f| *a = f))?;
321			} else {
322				//Applying an exponentiation S-box to *all* entries of the state
323				// vector, during full rounds
324				state[0] = self.params.sbox.apply_sbox(state[0])?;
325			}
326
327			// Multiplying the state vector by the MDS matrix.
328			state = state
329				.iter()
330				.enumerate()
331				.map(|(i, _)| {
332					state.iter().enumerate().fold(F::zero(), |acc, (j, a)| {
333						let m = self.params.mds_matrix[i][j];
334						acc.add(m.mul(*a))
335					})
336				})
337				.collect();
338		}
339
340		Ok(state[0])
341	}
342
343	fn hash_two(&self, left: &F, right: &F) -> Result<F, PoseidonError> {
344		self.hash(&[*left, *right])
345	}
346}
347
348#[cfg(test)]
349pub mod test {
350	use crate::poseidon::{FieldHasher, Poseidon, PoseidonParameters, PoseidonSbox};
351	use ark_ed_on_bn254::Fq;
352	use ark_ff::{fields::Field, PrimeField};
353	use ark_std::{vec::Vec, One};
354
355	use arkworks_utils::{
356		bytes_matrix_to_f, bytes_vec_to_f, parse_vec, poseidon_params::setup_poseidon_params, Curve,
357	};
358
359	pub fn setup_params<F: PrimeField>(curve: Curve, exp: i8, width: u8) -> PoseidonParameters<F> {
360		let pos_data = setup_poseidon_params(curve, exp, width).unwrap();
361
362		let mds_f = bytes_matrix_to_f(&pos_data.mds);
363		let rounds_f = bytes_vec_to_f(&pos_data.rounds);
364
365		let pos = PoseidonParameters {
366			mds_matrix: mds_f,
367			round_keys: rounds_f,
368			full_rounds: pos_data.full_rounds,
369			partial_rounds: pos_data.partial_rounds,
370			sbox: PoseidonSbox(pos_data.exp),
371			width: pos_data.width,
372		};
373
374		pos
375	}
376
377	type PoseidonHasher = Poseidon<Fq>;
378	#[test]
379	fn test_width_3_circom_bn_254() {
380		let curve = Curve::Bn254;
381
382		let params = setup_params(curve, 5, 3);
383		let poseidon = PoseidonHasher::new(params);
384
385		// output from circomlib, and here is the code.
386		// ```js
387		// const { poseidon } = require('circomlib');
388		// console.log(poseidon([1, 2]).toString(16));
389		// ```
390		let res: Vec<Fq> = bytes_vec_to_f(
391			&parse_vec(vec![
392				"0x115cc0f5e7d690413df64c6b9662e9cf2a3617f2743245519e19607a4417189a",
393			])
394			.unwrap(),
395		);
396		let left_input = Fq::one();
397		let right_input = Fq::one().double();
398		let poseidon_res = poseidon.hash_two(&left_input, &right_input).unwrap();
399
400		assert_eq!(res[0], poseidon_res, "{} != {}", res[0], poseidon_res);
401
402		// test two with 32 bytes.
403		// these bytes are randomly generated.
404		// and tested as the following:
405		// ```js
406		// const left = "0x" + Buffer.from([
407		// 		0x06, 0x9c, 0x63, 0x81, 0xac, 0x0b, 0x96, 0x8e, 0x88, 0x1c,
408		// 		0x91, 0x3c, 0x17, 0xd8, 0x36, 0x06, 0x7f, 0xd1, 0x5f, 0x2c,
409		// 		0xc7, 0x9f, 0x90, 0x2c, 0x80, 0x70, 0xb3, 0x6d, 0x28, 0x66,
410		// 		0x17, 0xdd
411		// ]).toString("hex");
412		// const right = "0x" + Buffer.from([
413		// 		0xc3, 0x3b, 0x60, 0x04, 0x2f, 0x76, 0xc7, 0xfb, 0xd0, 0x5d,
414		// 		0xb7, 0x76, 0x23, 0xcb, 0x17, 0xb8, 0x1d, 0x49, 0x41, 0x4b,
415		// 		0x82, 0xe5, 0x6a, 0x2e, 0xc0, 0x18, 0xf7, 0xa5, 0x5c, 0x3f,
416		// 		0x30, 0x0b
417		// ]).toString("hex");
418		// console.log({
419		// 		hash: "0x" + poseidon([left, right])
420		// 						.toString(16)
421		// 						.padStart(64, "0")
422		// 		});
423		// ```
424		//
425		// Here we should read the data as Big Endian and
426		// then we convert it to little endian.
427		let aaa: &[u8] = &[
428			0x06, 0x9c, 0x63, 0x81, 0xac, 0x0b, 0x96, 0x8e, 0x88, 0x1c, 0x91, 0x3c, 0x17, 0xd8,
429			0x36, 0x06, 0x7f, 0xd1, 0x5f, 0x2c, 0xc7, 0x9f, 0x90, 0x2c, 0x80, 0x70, 0xb3, 0x6d,
430			0x28, 0x66, 0x17, 0xdd,
431		];
432		let left_input = Fq::from_be_bytes_mod_order(aaa);
433		let right_input = Fq::from_be_bytes_mod_order(&[
434			0xc3, 0x3b, 0x60, 0x04, 0x2f, 0x76, 0xc7, 0xfb, 0xd0, 0x5d, 0xb7, 0x76, 0x23, 0xcb,
435			0x17, 0xb8, 0x1d, 0x49, 0x41, 0x4b, 0x82, 0xe5, 0x6a, 0x2e, 0xc0, 0x18, 0xf7, 0xa5,
436			0x5c, 0x3f, 0x30, 0x0b,
437		]);
438		let res: Vec<Fq> = bytes_vec_to_f(
439			&parse_vec(vec![
440				"0x0a13ad844d3487ad3dbaf3876760eb971283d48333fa5a9e97e6ee422af9554b",
441			])
442			.unwrap(),
443		);
444		let poseidon_res = poseidon.hash_two(&left_input, &right_input).unwrap();
445		assert_eq!(res[0], poseidon_res, "{} != {}", res[0], poseidon_res);
446	}
447
448	#[test]
449	fn test_compare_hashes_with_circom_bn_254() {
450		let curve = Curve::Bn254;
451
452		let parameters2 = setup_params(curve, 5, 2);
453		let parameters4 = setup_params(curve, 5, 4);
454		let parameters5 = setup_params(curve, 5, 5);
455
456		let poseidon2 = Poseidon::new(parameters2);
457		let poseidon4 = Poseidon::new(parameters4);
458		let poseidon5 = Poseidon::new(parameters5);
459
460		let expected_public_key: Vec<Fq> = bytes_vec_to_f(
461			&parse_vec(vec![
462				"0x07a1f74bf9feda741e1e9099012079df28b504fc7a19a02288435b8e02ae21fa",
463			])
464			.unwrap(),
465		);
466
467		let private_key: Vec<Fq> = bytes_vec_to_f(
468			&parse_vec(vec![
469				"0xb2ac10dccfb5a5712d632464a359668bb513e80e9d145ab5a88381de83af1046",
470			])
471			.unwrap(),
472		);
473		// let input = private_key[0];
474
475		let computed_public_key = poseidon2.hash(&private_key).unwrap();
476		println!("poseidon_res = {:?}", computed_public_key);
477		//println!("expected_res = {:?}", res[0]);
478		assert_eq!(
479			expected_public_key[0], computed_public_key,
480			"{} != {}",
481			expected_public_key[0], computed_public_key
482		);
483
484		let chain_id: Vec<Fq> = bytes_vec_to_f(
485			&parse_vec(vec![
486				"0x0000000000000000000000000000000000000000000000000000000000007a69",
487			])
488			.unwrap(),
489		);
490		let amount: Vec<Fq> = bytes_vec_to_f(
491			&parse_vec(vec![
492				"0x0000000000000000000000000000000000000000000000000000000000989680",
493			])
494			.unwrap(),
495		);
496		let blinding: Vec<Fq> = bytes_vec_to_f(
497			&parse_vec(vec![
498				"0x00a668ba0dcb34960aca597f433d0d3289c753046afa26d97e1613148c05f2c0",
499			])
500			.unwrap(),
501		);
502
503		let expected_leaf: Vec<Fq> = bytes_vec_to_f(
504			&parse_vec(vec![
505				"0x15206d966a7fb3e3fbbb7f4d7b623ca1c7c9b5c6e6d0a3348df428189441a1e4",
506			])
507			.unwrap(),
508		);
509		let mut input = vec![chain_id[0]];
510		input.push(amount[0]);
511		input.push(expected_public_key[0]);
512		input.push(blinding[0]);
513		let computed_leaf = poseidon5.hash(&input).unwrap();
514
515		assert_eq!(
516			expected_leaf[0], computed_leaf,
517			"{} != {}",
518			expected_leaf[0], computed_leaf
519		);
520
521		let path_index: Vec<Fq> = bytes_vec_to_f(
522			&parse_vec(vec![
523				"0x0000000000000000000000000000000000000000000000000000000000000000",
524			])
525			.unwrap(),
526		);
527		let expected_nullifier: Vec<Fq> = bytes_vec_to_f(
528			&parse_vec(vec![
529				"0x21423c7374ce5b3574f04f92243449359ae3865bb8e34cb2b7b5e4187ba01fca",
530			])
531			.unwrap(),
532		);
533		let mut input = vec![expected_leaf[0]];
534		input.push(path_index[0]);
535		input.push(private_key[0]);
536
537		let computed_nullifier = poseidon4.hash(&input).unwrap();
538
539		assert_eq!(
540			expected_nullifier[0], computed_nullifier,
541			"{} != {}",
542			expected_nullifier[0], computed_nullifier
543		);
544	}
545
546	#[test]
547	fn test_parameter_to_and_from_bytes() {
548		let curve = Curve::Bn254;
549		let params = setup_params::<Fq>(curve, 5, 3);
550
551		let bytes = params.to_bytes();
552		let new_params: PoseidonParameters<Fq> = PoseidonParameters::from_bytes(&bytes).unwrap();
553		assert_eq!(bytes, new_params.to_bytes());
554	}
555}