1#[cfg(feature = "alloc")]
9extern crate alloc;
10
11#[cfg(feature = "alloc")]
12use alloc::vec::Vec;
13
14use lib_q_stark_field::extension::Complex;
15use lib_q_stark_mersenne31::Mersenne31;
16
17use crate::params::{
18 Poseidon128,
19 Poseidon256,
20 PoseidonField,
21 PoseidonParams,
22};
23use crate::permutation::PoseidonPermutation;
24
25#[derive(Debug, Clone)]
26struct SpongeState {
27 permutation: PoseidonPermutation,
28 state: Vec<PoseidonField>,
29 rate: usize,
30 capacity: usize,
31 absorbed: usize,
32}
33
34impl SpongeState {
35 fn new(params: PoseidonParams) -> Self {
36 use lib_q_stark_field::PrimeCharacteristicRing;
37 let state_width = params.state_width;
38 Self {
39 permutation: PoseidonPermutation::new(params.clone()),
40 state: alloc::vec![Complex::<Mersenne31>::ZERO; state_width],
41 rate: params.rate,
42 capacity: params.capacity,
43 absorbed: 0,
44 }
45 }
46
47 fn absorb(&mut self, elements: &[PoseidonField]) {
48 for &element in elements {
49 self.state[self.absorbed] += element;
50 self.absorbed += 1;
51
52 if self.absorbed >= self.rate {
53 self.state = self.permutation.permute(self.state.clone());
54 self.absorbed = 0;
55 }
56 }
57 }
58
59 fn apply_padding_and_permute(mut self) -> Self {
61 use lib_q_stark_field::PrimeCharacteristicRing;
62
63 self.state[self.absorbed] += Complex::<Mersenne31>::ONE;
64 if self.absorbed + 1 < self.rate {
65 self.state[self.rate - 1] += Complex::<Mersenne31>::ONE;
66 }
67
68 self.state = self.permutation.permute(self.state.clone());
69 self.absorbed = 0;
70 self
71 }
72
73 fn squeeze(&mut self, num_elements: usize) -> Vec<PoseidonField> {
74 let mut output = Vec::with_capacity(num_elements);
75 let mut squeezed = 0;
76
77 while squeezed < num_elements {
78 if self.absorbed >= self.rate {
79 self.state = self.permutation.permute(self.state.clone());
80 self.absorbed = 0;
81 }
82
83 output.push(self.state[self.absorbed]);
84 self.absorbed += 1;
85 squeezed += 1;
86 }
87
88 output
89 }
90}
91
92#[derive(Debug, Clone)]
98pub struct PoseidonSponge(SpongeState);
99
100impl PoseidonSponge {
101 pub fn new(params: PoseidonParams) -> Self {
103 Self(SpongeState::new(params))
104 }
105
106 pub fn absorb(&mut self, elements: &[PoseidonField]) {
112 self.0.absorb(elements);
113 }
114
115 pub fn finish_absorbing(self) -> PoseidonSpongeSqueeze {
138 PoseidonSpongeSqueeze(self.0.apply_padding_and_permute())
139 }
140
141 pub fn finalize(self) -> Vec<PoseidonField> {
147 self.finish_absorbing().into_state()
148 }
149
150 pub fn capacity(&self) -> usize {
152 self.0.capacity
153 }
154
155 pub fn rate(&self) -> usize {
157 self.0.rate
158 }
159}
160
161#[derive(Debug, Clone)]
166pub struct PoseidonSpongeSqueeze(SpongeState);
167
168impl PoseidonSpongeSqueeze {
169 pub fn squeeze(&mut self, num_elements: usize) -> Vec<PoseidonField> {
179 self.0.squeeze(num_elements)
180 }
181
182 pub fn into_state(self) -> Vec<PoseidonField> {
184 self.0.state
185 }
186
187 pub fn capacity(&self) -> usize {
189 self.0.capacity
190 }
191
192 pub fn rate(&self) -> usize {
194 self.0.rate
195 }
196}
197
198pub trait Poseidon {
200 fn hash(&self, input: &[PoseidonField]) -> Vec<PoseidonField>;
210
211 fn hash_single(&self, input: &[PoseidonField]) -> PoseidonField {
221 self.hash(input)[0]
222 }
223}
224
225impl Poseidon for Poseidon128 {
226 fn hash(&self, input: &[PoseidonField]) -> Vec<PoseidonField> {
227 let params = Self::params();
228 let mut sponge = PoseidonSponge::new(params);
229 sponge.absorb(input);
230 let mut sponge = sponge.finish_absorbing();
231 sponge.squeeze(1)
232 }
233}
234
235impl Poseidon for Poseidon256 {
236 fn hash(&self, input: &[PoseidonField]) -> Vec<PoseidonField> {
237 let params = Self::params();
238 let mut sponge = PoseidonSponge::new(params);
239 sponge.absorb(input);
240 let mut sponge = sponge.finish_absorbing();
241 sponge.squeeze(1)
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// The two-element message `[1, 2]`, lifted into the complex extension, shared by the tests.
    fn sample_input() -> Vec<Complex<Mersenne31>> {
        alloc::vec![
            Complex::<Mersenne31>::from(Mersenne31::new(1)),
            Complex::<Mersenne31>::from(Mersenne31::new(2)),
        ]
    }

    #[test]
    fn test_sponge_absorb_squeeze() {
        let mut absorbing = PoseidonSponge::new(Poseidon128::params());
        absorbing.absorb(&sample_input());

        let output = absorbing.finish_absorbing().squeeze(1);
        assert_eq!(output.len(), 1);
    }

    #[test]
    fn test_poseidon_hash_deterministic() {
        let hasher = Poseidon128;
        let input = sample_input();

        let first = hasher.hash(&input);
        let second = hasher.hash(&input);
        assert_eq!(first, second);
    }
}