poseidon_bn128/
lib.rs

1use anyhow::Result;
2use num_bigint::BigUint;
3use num_traits::Num;
4use scalarff::Bn128FieldElement;
5use scalarff::FieldElement;
6use serde::Deserialize;
7use serde::Serialize;
8
9/// Representation for use with serde
10#[derive(Debug, Serialize, Deserialize)]
11#[allow(non_snake_case)]
12pub struct PoseidonParamsSerialized {
13    pub C: Vec<String>,
14    pub M: Vec<Vec<String>>,
15}
16
17/// Representation for use with the poseidon logic
18pub struct PoseidonParams {
19    pub c: Vec<Bn128FieldElement>,
20    pub m: Vec<Vec<Bn128FieldElement>>,
21    pub num_full_rounds: usize,
22    pub num_partial_rounds: usize,
23}
24
// The parameter sets are embedded in the binary with `include_str!`, so no
// files need to be read at runtime. Entry `i` holds the JSON parameters for
// `i + 1` inputs; `read_constants` looks them up by `input_count - 1`. Note
// that the JSON text is still parsed on every `read_constants` call.
//
// The downside is that the constants for _all_ input counts are included even
// if only one is used, e.g. if you only use poseidon2 you still carry the
// constants for every other width.
const POSEIDON_CONSTANTS: [&str; 16] = [
    include_str!("params-json/1.json"),
    include_str!("params-json/2.json"),
    include_str!("params-json/3.json"),
    include_str!("params-json/4.json"),
    include_str!("params-json/5.json"),
    include_str!("params-json/6.json"),
    include_str!("params-json/7.json"),
    include_str!("params-json/8.json"),
    include_str!("params-json/9.json"),
    include_str!("params-json/10.json"),
    include_str!("params-json/11.json"),
    include_str!("params-json/12.json"),
    include_str!("params-json/13.json"),
    include_str!("params-json/14.json"),
    include_str!("params-json/15.json"),
    include_str!("params-json/16.json"),
];
50
51fn pow5(v: Bn128FieldElement) -> Bn128FieldElement {
52    let square = v * v;
53    let quad = square * square;
54    quad * v
55}
56
57fn mix(state: Vec<Bn128FieldElement>, params: &PoseidonParams) -> Vec<Bn128FieldElement> {
58    let mut out = vec![];
59    for i in 0..state.len() {
60        let mut o = Bn128FieldElement::zero();
61        #[allow(clippy::needless_range_loop)]
62        for j in 0..state.len() {
63            o += params.m[i][j] * state[j];
64        }
65        out.push(o);
66    }
67    out
68}
69
70/// Calculate the poseidon hash on the alt_bn128 curve. Instantiated with
71/// the same parameters as the circomlib implementation of poseidon. The first
72/// argument is the number of inputs, the second argument is the inputs.
73/// This is so that accidentally sized input vectors are caught more easily.
74/// The input vector is an unsized slice for more simple compatiblity.
75/// This function errors if input_count != input.len()
76pub fn poseidon(input_count: u8, input: &[Bn128FieldElement]) -> Result<Bn128FieldElement> {
77    if input.len() != usize::from(input_count) {
78        anyhow::bail!("expected {} inputs, received {}", input_count, input.len());
79    }
80    // constants are stored by number of inputs
81    let params = read_constants(input_count)?;
82    let t = usize::from(input_count + 1);
83
84    let mut state = [Bn128FieldElement::zero()]
85        .iter()
86        .chain(input)
87        .copied()
88        .collect::<Vec<Bn128FieldElement>>();
89
90    for x in 0..(params.num_full_rounds + params.num_partial_rounds) {
91        #[allow(clippy::needless_range_loop)]
92        for y in 0..state.len() {
93            state[y] += params.c[x * t + y];
94            if y == 0
95                || x < params.num_full_rounds / 2
96                || x >= params.num_full_rounds / 2 + params.num_partial_rounds
97            {
98                state[y] = pow5(state[y]);
99            }
100        }
101        state = mix(state, &params);
102    }
103    Ok(state[0])
104}
105
106/// Deserialize the constants from string (json) representation and return
107/// a structure with scalarff::FieldElement types
108pub fn read_constants(input_count: u8) -> Result<PoseidonParams> {
109    let params: PoseidonParamsSerialized =
110        serde_json::from_str(POSEIDON_CONSTANTS[usize::from(input_count - 1)])?;
111    let partial_round_counts = [
112        56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68,
113    ];
114    // TODO: move this into scalarff?
115    let hex_str_to_field_element = |x: &String| {
116        Ok(Bn128FieldElement::from_biguint(&BigUint::from_str_radix(
117            &x[2..],
118            16,
119        )?))
120    };
121    Ok(PoseidonParams {
122        num_full_rounds: 8,
123        num_partial_rounds: partial_round_counts[usize::from(input_count) - 1],
124        c: params
125            .C
126            .iter()
127            .map(hex_str_to_field_element)
128            .collect::<Result<_>>()?,
129        m: params
130            .M
131            .iter()
132            .map(|internal| internal.iter().map(hex_str_to_field_element).collect())
133            .collect::<Result<_>>()?,
134    })
135}
136
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use anyhow::Result;
    use scalarff::Bn128FieldElement;
    use scalarff::FieldElement;

    /// Compare against the reference hashes in `src/test_hashes.json`.
    ///
    /// Row `i` of the file holds hashes for `i + 1` inputs. Entry `j` of that
    /// row is the hash of `i + 1` copies of the value `j`, stored as a
    /// prefixed hex string.
    #[test]
    fn compare_hashes() -> Result<()> {
        // Resolve relative to the crate root so the test does not depend on
        // the working directory it is launched from.
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/src/test_hashes.json");
        let expected: Vec<Vec<String>> = serde_json::from_str(&std::fs::read_to_string(path)?)?;
        for (i, hashes) in expected.iter().enumerate() {
            let input_count = u8::try_from(i + 1)?;
            let hash_count = hashes.len();
            let start = Instant::now();
            for (j, expected_hash) in hashes.iter().enumerate() {
                let inputs =
                    vec![Bn128FieldElement::from(u64::try_from(j)?); usize::from(input_count)];
                let hash = super::poseidon(input_count, &inputs)?;
                assert_eq!(hash.to_biguint().to_str_radix(16), expected_hash[2..]);
            }
            let elapsed = start.elapsed();
            println!(
                "Calculated {hash_count} poseidon{input_count} hashes in: {:.2?}",
                elapsed
            );
        }
        Ok(())
    }

    /// A count that disagrees with the slice length is reported as an error.
    #[test]
    fn rejects_mismatched_input_length() {
        let one = Bn128FieldElement::from(1u64);
        assert!(super::poseidon(2, &[one]).is_err());
        assert!(super::poseidon(1, &[one, one]).is_err());
    }
}