1pub mod constants;
2pub mod sponge;
3
4use ark_ff::Field;
5use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
6
7#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)]
8pub struct PoseidonConstants<F: Field> {
9 pub round_keys: Vec<F>,
10 pub mds_matrix: Vec<Vec<F>>,
11 pub num_full_rounds: usize,
12 pub num_partial_rounds: usize,
13}
14
15const CAPACITY: usize = 1; #[derive(Clone, CanonicalDeserialize, CanonicalSerialize)]
18pub struct Poseidon<F: Field, const WIDTH: usize> {
19 pub state: [F; WIDTH],
20 pub constants: PoseidonConstants<F>,
21 pub pos: usize,
22}
23
24impl<F: Field, const WIDTH: usize> Poseidon<F, WIDTH> {
25 pub fn new(constants: PoseidonConstants<F>) -> Self {
26 let state = [F::zero(); WIDTH];
27 Self {
28 state,
29 constants,
30 pos: 0,
31 }
32 }
33
34 pub fn permute(&mut self) {
35 let full_rounds_half = self.constants.num_full_rounds / 2;
36
37 for _ in 0..full_rounds_half {
39 self.full_round();
40 }
41
42 for _ in 0..self.constants.num_partial_rounds {
44 self.partial_round();
45 }
46
47 for _ in 0..full_rounds_half {
49 self.full_round();
50 }
51 }
52
53 pub fn reset(&mut self) {
54 self.state = [F::zero(); WIDTH];
55 self.pos = 0;
56 }
57
58 fn add_constants(&mut self) {
59 for i in 0..self.state.len() {
61 self.state[i] += self.constants.round_keys[i + self.pos];
62 }
63 }
64
65 fn matrix_mul(&mut self) {
67 let mut result = [F::zero(); WIDTH];
68 for (i, val) in self.constants.mds_matrix.iter().enumerate() {
69 let mut tmp = F::zero();
70 for (j, element) in self.state.iter().enumerate() {
71 tmp += val[j] * element
72 }
73 result[i] = tmp;
74 }
75
76 self.state = result;
77 }
78
79 fn full_round(&mut self) {
80 let t = self.state.len();
81 self.add_constants();
82
83 for i in 0..t {
85 self.state[i] = self.state[i].square().square() * self.state[i];
86 }
87
88 self.matrix_mul();
89
90 self.pos += self.state.len();
92 }
93
94 fn partial_round(&mut self) {
95 self.add_constants();
96
97 self.state[0] = self.state[0].square().square() * self.state[0];
99
100 self.matrix_mul();
101
102 self.pos += self.state.len();
104 }
105}