anonklub_poseidon/
lib.rs

1pub mod constants;
2pub mod sponge;
3
4use ark_ff::Field;
5use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
6
7#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)]
8pub struct PoseidonConstants<F: Field> {
9    pub round_keys: Vec<F>,
10    pub mds_matrix: Vec<Vec<F>>,
11    pub num_full_rounds: usize,
12    pub num_partial_rounds: usize,
13}
14
/// Sponge capacity; deliberately fixed at one rather than made configurable.
const CAPACITY: usize = 1;
16
17#[derive(Clone, CanonicalDeserialize, CanonicalSerialize)]
18pub struct Poseidon<F: Field, const WIDTH: usize> {
19    pub state: [F; WIDTH],
20    pub constants: PoseidonConstants<F>,
21    pub pos: usize,
22}
23
24impl<F: Field, const WIDTH: usize> Poseidon<F, WIDTH> {
25    pub fn new(constants: PoseidonConstants<F>) -> Self {
26        let state = [F::zero(); WIDTH];
27        Self {
28            state,
29            constants,
30            pos: 0,
31        }
32    }
33
34    pub fn permute(&mut self) {
35        let full_rounds_half = self.constants.num_full_rounds / 2;
36
37        // First half of full rounds
38        for _ in 0..full_rounds_half {
39            self.full_round();
40        }
41
42        // Partial rounds
43        for _ in 0..self.constants.num_partial_rounds {
44            self.partial_round();
45        }
46
47        // Second half of full rounds
48        for _ in 0..full_rounds_half {
49            self.full_round();
50        }
51    }
52
53    pub fn reset(&mut self) {
54        self.state = [F::zero(); WIDTH];
55        self.pos = 0;
56    }
57
58    fn add_constants(&mut self) {
59        // Add round constants
60        for i in 0..self.state.len() {
61            self.state[i] += self.constants.round_keys[i + self.pos];
62        }
63    }
64
65    // MDS matrix multiplication
66    fn matrix_mul(&mut self) {
67        let mut result = [F::zero(); WIDTH];
68        for (i, val) in self.constants.mds_matrix.iter().enumerate() {
69            let mut tmp = F::zero();
70            for (j, element) in self.state.iter().enumerate() {
71                tmp += val[j] * element
72            }
73            result[i] = tmp;
74        }
75
76        self.state = result;
77    }
78
79    fn full_round(&mut self) {
80        let t = self.state.len();
81        self.add_constants();
82
83        // S-boxes
84        for i in 0..t {
85            self.state[i] = self.state[i].square().square() * self.state[i];
86        }
87
88        self.matrix_mul();
89
90        // Update the position of the round constants that are added
91        self.pos += self.state.len();
92    }
93
94    fn partial_round(&mut self) {
95        self.add_constants();
96
97        // S-box
98        self.state[0] = self.state[0].square().square() * self.state[0];
99
100        self.matrix_mul();
101
102        // Update the position of the round constants that are added
103        self.pos += self.state.len();
104    }
105}