// poseidon_rs/lib.rs

1extern crate rand;
2#[macro_use]
3extern crate ff;
4use ff::*;
5
/// Prime-field element type generated by the `ff` `PrimeField` derive.
///
/// All arithmetic is modulo `PrimeFieldModulus`; `PrimeFieldGenerator` supplies the
/// field generator the derive requires. `FrRepr` is not defined in this file —
/// presumably it is emitted by the derive (TODO confirm).
/// NOTE(review): this modulus matches the BN254 (alt_bn128) scalar-field order —
/// confirm before relying on interoperability with other implementations.
#[derive(PrimeField)]
#[PrimeFieldModulus = "21888242871839275222246405745257275088548364400416034343698204186575808495617"]
#[PrimeFieldGenerator = "7"]
pub struct Fr(FrRepr);
10
11mod constants;
12
/// Poseidon parameter tables.
///
/// Every per-width table is indexed by `t - 2`, where `t = inputs + 1` is the
/// state width used by `Poseidon::hash`.
#[derive(Debug)]
pub struct Constants {
    /// Round constants: `c[t - 2]` is a flat list consumed `t` entries per round.
    pub c: Vec<Vec<Fr>>,
    /// Mixing matrices: `m[t - 2]` is read as a `t x t` matrix every round.
    pub m: Vec<Vec<Vec<Fr>>>,
    /// Full-round count: the first `n_rounds_f / 2` rounds are full, and the
    /// remaining full rounds come after the partial rounds.
    pub n_rounds_f: usize,
    /// Partial-round counts: `n_rounds_p[t - 2]` is used for state width `t`.
    pub n_rounds_p: Vec<usize>,
}
20pub fn load_constants() -> Constants {
21    let (c_str, m_str) = constants::constants();
22    let mut c: Vec<Vec<Fr>> = Vec::new();
23    for i in 0..c_str.len() {
24        let mut cci: Vec<Fr> = Vec::new();
25        for j in 0..c_str[i].len() {
26            let b: Fr = Fr::from_str(c_str[i][j]).unwrap();
27            cci.push(b);
28        }
29        c.push(cci);
30    }
31    let mut m: Vec<Vec<Vec<Fr>>> = Vec::new();
32    for i in 0..m_str.len() {
33        let mut mi: Vec<Vec<Fr>> = Vec::new();
34        for j in 0..m_str[i].len() {
35            let mut mij: Vec<Fr> = Vec::new();
36            for k in 0..m_str[i][j].len() {
37                let b: Fr = Fr::from_str(m_str[i][j][k]).unwrap();
38                mij.push(b);
39            }
40            mi.push(mij);
41        }
42        m.push(mi);
43    }
44    Constants {
45        c,
46        m,
47        n_rounds_f: 8,
48        n_rounds_p: vec![
49            56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68,
50        ],
51    }
52}
53
54pub struct Poseidon {
55    constants: Constants,
56}
57impl Poseidon {
58    pub fn new() -> Poseidon {
59        Poseidon {
60            constants: load_constants(),
61        }
62    }
63    pub fn ark(&self, state: &mut Vec<Fr>, c: &Vec<Fr>, it: usize) {
64        for i in 0..state.len() {
65            state[i].add_assign(&c[it + i]);
66        }
67    }
68
69    pub fn sbox(&self, n_rounds_f: usize, n_rounds_p: usize, state: &mut Vec<Fr>, i: usize) {
70        if i < n_rounds_f / 2 || i >= n_rounds_f / 2 + n_rounds_p {
71            for j in 0..state.len() {
72                let aux = state[j];
73                state[j].square();
74                state[j].square();
75                state[j].mul_assign(&aux);
76            }
77        } else {
78            let aux = state[0];
79            state[0].square();
80            state[0].square();
81            state[0].mul_assign(&aux);
82        }
83    }
84
85    pub fn mix(&self, state: &Vec<Fr>, m: &Vec<Vec<Fr>>) -> Vec<Fr> {
86        let mut new_state: Vec<Fr> = Vec::new();
87        for i in 0..state.len() {
88            new_state.push(Fr::zero());
89            for j in 0..state.len() {
90                let mut mij = m[i][j];
91                mij.mul_assign(&state[j]);
92                new_state[i].add_assign(&mij);
93            }
94        }
95        new_state.clone()
96    }
97
98    pub fn hash(&self, inp: Vec<Fr>) -> Result<Fr, String> {
99        let t = inp.len() + 1;
100        // if inp.len() == 0 || inp.len() >= self.constants.n_rounds_p.len() - 1 {
101        if inp.is_empty() || inp.len() > self.constants.n_rounds_p.len() {
102            return Err("Wrong inputs length".to_string());
103        }
104        let n_rounds_f = self.constants.n_rounds_f.clone();
105        let n_rounds_p = self.constants.n_rounds_p[t - 2].clone();
106
107        let mut state = vec![Fr::zero(); t];
108        state[1..].clone_from_slice(&inp);
109
110        for i in 0..(n_rounds_f + n_rounds_p) {
111            self.ark(&mut state, &self.constants.c[t - 2], i * t);
112            self.sbox(n_rounds_f, n_rounds_p, &mut state, i);
113            state = self.mix(&state, &self.constants.m[t - 2]);
114        }
115
116        Ok(state[0])
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a field element from a small integer.
    fn fr(x: u64) -> Fr {
        Fr::from_repr(FrRepr::from(x)).unwrap()
    }

    #[test]
    fn test_ff() {
        let a = Fr::from_repr(FrRepr::from(2)).unwrap();
        assert_eq!(
            "0000000000000000000000000000000000000000000000000000000000000002",
            to_hex(&a)
        );

        // modulus + 2 must reduce to 2.
        let b: Fr = Fr::from_str(
            "21888242871839275222246405745257275088548364400416034343698204186575808495619",
        )
        .unwrap();
        assert_eq!(
            "0000000000000000000000000000000000000000000000000000000000000002",
            to_hex(&b)
        );
        assert_eq!(&a, &b);
    }

    // NOTE(review): this test was commented out; its expected values are unverified
    // against the current constant tables. Run with `cargo test -- --ignored` and
    // either fix the expectations or delete the test.
    #[test]
    #[ignore]
    fn test_load_constants() {
        let cons = load_constants();
        assert_eq!(
            cons.c[0][0].to_string(),
            "Fr(0x09c46e9ec68e9bd4fe1faaba294cba38a71aa177534cdd1b6c7dc0dbd0abd7a7)"
        );
        assert_eq!(
            cons.c[cons.c.len() - 1][0].to_string(),
            "Fr(0x2088ce9534577bf38be7bc457f2756d558d66e0c07b9cc001a580bd42cda0e77)"
        );
        assert_eq!(
            cons.m[0][0][0].to_string(),
            "Fr(0x066f6f85d6f68a85ec10345351a23a3aaf07f38af8c952a7bceca70bd2af7ad5)"
        );
        assert_eq!(
            cons.m[cons.m.len() - 1][0][0].to_string(),
            "Fr(0x0190f922d97c8a7dcf0a142a3be27749d1c64bc22f1c556aaa24925d158cac56)"
        );
    }

    #[test]
    fn test_hash() {
        let poseidon = Poseidon::new();

        // (inputs, expected `Display` of the hash); the decimal form of each
        // expected value is noted above its entry.
        let cases: Vec<(Vec<u64>, &str)> = vec![
            // 18586133768512220936620570745912940619677854269274689475585506675881198879027
            (
                vec![1],
                "Fr(0x29176100eaa962bdc1fe6c654d6a3c130e96a4d1168b33848b897dc502820133)",
            ),
            // 7853200120776062878684798364095072458815029376092732009249414926327459813530
            (
                vec![1, 2],
                "Fr(0x115cc0f5e7d690413df64c6b9662e9cf2a3617f2743245519e19607a4417189a)",
            ),
            // 1018317224307729531995786483840663576608797660851238720571059489595066344487
            (
                vec![1, 2, 0, 0, 0],
                "Fr(0x024058dd1e168f34bac462b6fffe58fd69982807e9884c1c6148182319cee427)",
            ),
            // 15336558801450556532856248569924170992202208561737609669134139141992924267169
            (
                vec![1, 2, 0, 0, 0, 0],
                "Fr(0x21e82f465e00a15965e97a44fe3c30f3bf5279d8bf37d4e65765b6c2550f42a1)",
            ),
            // 5811595552068139067952687508729883632420015185677766880877743348592482390548
            (
                vec![3, 4, 0, 0, 0],
                "Fr(0x0cd93f1bab9e8c9166ef00f2a1b0e1d66d6a4145e596abe0526247747cc71214)",
            ),
            // 12263118664590987767234828103155242843640892839966517009184493198782366909018
            (
                vec![3, 4, 0, 0, 0, 0],
                "Fr(0x1b1caddfc5ea47e09bb445a7447eb9694b8d1b75a97fff58e884398c6b22825a)",
            ),
            // 20400040500897583745843009878988256314335038853985262692600694741116813247201
            (
                vec![1, 2, 3, 4, 5, 6],
                "Fr(0x2d1a03850084442813c8ebf094dea47538490a68b05f2239134a4cca2f6302e1)",
            ),
            // 8354478399926161176778659061636406690034081872658507739535256090879947077494
            (
                vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
                "Fr(0x1278779aaafc5ca58bf573151005830cdb4683fb26591c85a7464d4f0e527776)",
            ),
            // 5540388656744764564518487011617040650780060800286365721923524861648744699539
            (
                vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0],
                "Fr(0x0c3fbfb4d3f583df4124b4b3ac94ca3a0a1948a89fef727204d89de1c4d35693)",
            ),
            // 11882816200654282475720830292386643970958445617880627439994635298904836126497
            (
                vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0],
                "Fr(0x1a456f8563b98c9649877f38b7e36534b241c29d457d307c481cbd12b69bb721)",
            ),
            // 9989051620750914585850546081941653841776809718687451684622678807385399211877
            (
                vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
                "Fr(0x16159a551cbb66108281a48099fff949ae08afd7f1f2ec06de2ffb96b919b765)",
            ),
        ];

        for (inputs, expected) in cases {
            let inp: Vec<Fr> = inputs.iter().map(|&x| fr(x)).collect();
            let h = poseidon.hash(inp).unwrap();
            assert_eq!(h.to_string(), expected, "inputs: {:?}", inputs);
        }
    }

    #[test]
    fn test_wrong_inputs() {
        let poseidon = Poseidon::new();

        // Empty input is rejected.
        assert_eq!(
            poseidon.hash(Vec::new()),
            Err("Wrong inputs length".to_string())
        );

        // 17 inputs is one more than the largest length accepted in `test_hash`.
        let mut big_arr = vec![fr(1), fr(2)];
        big_arr.resize(17, fr(0));
        assert_eq!(
            poseidon.hash(big_arr),
            Err("Wrong inputs length".to_string())
        );
    }
}