1use anyhow::Result;
2use num_bigint::BigUint;
3use num_traits::Num;
4use scalarff::Bn128FieldElement;
5use scalarff::FieldElement;
6use serde::Deserialize;
7use serde::Serialize;
8
9#[derive(Debug, Serialize, Deserialize)]
11#[allow(non_snake_case)]
12pub struct PoseidonParamsSerialized {
13 pub C: Vec<String>,
14 pub M: Vec<Vec<String>>,
15}
16
17pub struct PoseidonParams {
19 pub c: Vec<Bn128FieldElement>,
20 pub m: Vec<Vec<Bn128FieldElement>>,
21 pub num_full_rounds: usize,
22 pub num_partial_rounds: usize,
23}
24
25const POSEIDON_CONSTANTS: [&str; 16] = [
33 include_str!("params-json/1.json"),
34 include_str!("params-json/2.json"),
35 include_str!("params-json/3.json"),
36 include_str!("params-json/4.json"),
37 include_str!("params-json/5.json"),
38 include_str!("params-json/6.json"),
39 include_str!("params-json/7.json"),
40 include_str!("params-json/8.json"),
41 include_str!("params-json/9.json"),
42 include_str!("params-json/10.json"),
43 include_str!("params-json/11.json"),
44 include_str!("params-json/12.json"),
45 include_str!("params-json/13.json"),
46 include_str!("params-json/14.json"),
47 include_str!("params-json/15.json"),
48 include_str!("params-json/16.json"),
49];
50
51fn pow5(v: Bn128FieldElement) -> Bn128FieldElement {
52 let square = v * v;
53 let quad = square * square;
54 quad * v
55}
56
57fn mix(state: Vec<Bn128FieldElement>, params: &PoseidonParams) -> Vec<Bn128FieldElement> {
58 let mut out = vec![];
59 for i in 0..state.len() {
60 let mut o = Bn128FieldElement::zero();
61 #[allow(clippy::needless_range_loop)]
62 for j in 0..state.len() {
63 o += params.m[i][j] * state[j];
64 }
65 out.push(o);
66 }
67 out
68}
69
70pub fn poseidon(input_count: u8, input: &[Bn128FieldElement]) -> Result<Bn128FieldElement> {
77 if input.len() != usize::from(input_count) {
78 anyhow::bail!("expected {} inputs, received {}", input_count, input.len());
79 }
80 let params = read_constants(input_count)?;
82 let t = usize::from(input_count + 1);
83
84 let mut state = [Bn128FieldElement::zero()]
85 .iter()
86 .chain(input)
87 .copied()
88 .collect::<Vec<Bn128FieldElement>>();
89
90 for x in 0..(params.num_full_rounds + params.num_partial_rounds) {
91 #[allow(clippy::needless_range_loop)]
92 for y in 0..state.len() {
93 state[y] += params.c[x * t + y];
94 if y == 0
95 || x < params.num_full_rounds / 2
96 || x >= params.num_full_rounds / 2 + params.num_partial_rounds
97 {
98 state[y] = pow5(state[y]);
99 }
100 }
101 state = mix(state, ¶ms);
102 }
103 Ok(state[0])
104}
105
106pub fn read_constants(input_count: u8) -> Result<PoseidonParams> {
109 let params: PoseidonParamsSerialized =
110 serde_json::from_str(POSEIDON_CONSTANTS[usize::from(input_count - 1)])?;
111 let partial_round_counts = [
112 56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68,
113 ];
114 let hex_str_to_field_element = |x: &String| {
116 Ok(Bn128FieldElement::from_biguint(&BigUint::from_str_radix(
117 &x[2..],
118 16,
119 )?))
120 };
121 Ok(PoseidonParams {
122 num_full_rounds: 8,
123 num_partial_rounds: partial_round_counts[usize::from(input_count) - 1],
124 c: params
125 .C
126 .iter()
127 .map(hex_str_to_field_element)
128 .collect::<Result<_>>()?,
129 m: params
130 .M
131 .iter()
132 .map(|internal| internal.iter().map(hex_str_to_field_element).collect())
133 .collect::<Result<_>>()?,
134 })
135}
136
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::BufReader;
    use std::time::Instant;

    use anyhow::Result;
    use scalarff::Bn128FieldElement;
    use scalarff::FieldElement;

    /// Checks `poseidon` against the reference hashes in `src/test_hashes.json`.
    ///
    /// Row `i` of the JSON holds expected hashes for `i + 1` inputs. Entry `j` of
    /// that row is the hash of `i + 1` copies of the value `j`, hex-encoded with a
    /// `0x` prefix.
    #[test]
    fn compare_hashes() -> Result<()> {
        // Resolve against the crate root so the test does not depend on the
        // working directory it is launched from.
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/src/test_hashes.json");
        let expected: Vec<Vec<String>> =
            serde_json::from_reader(BufReader::new(File::open(path)?))?;
        for (i, hashes) in expected.iter().enumerate() {
            let input_count = u8::try_from(i + 1)?;
            let hash_count = hashes.len();
            let start = Instant::now();
            for (j, expected_hash) in hashes.iter().enumerate() {
                let input =
                    vec![Bn128FieldElement::from(u64::try_from(j)?); usize::from(input_count)];
                let hash = super::poseidon(input_count, &input)?;
                let expected_hex = expected_hash.strip_prefix("0x").ok_or_else(|| {
                    anyhow::anyhow!("expected a 0x-prefixed hash, got {:?}", expected_hash)
                })?;
                assert_eq!(
                    hash.to_biguint().to_str_radix(16),
                    expected_hex,
                    "hash mismatch for {input_count} inputs of value {j}"
                );
            }
            let elapsed = start.elapsed();
            println!(
                "Calculated {hash_count} poseidon{input_count} hashes in: {:.2?}",
                elapsed
            );
        }
        Ok(())
    }
}