1use core::mem::size_of;
7
8use p3_field::Field;
9
10use crate::{Alignable, StatefulHasher};
11
/// Adapter that lets a word-oriented stateful hasher absorb field elements.
///
/// Field elements are serialized into a stream of `u8`, `u32` or `u64` words (or, for
/// `[F; M]` inputs, into `M` parallel lanes of such words). The resulting words are then
/// forwarded to the wrapped `inner` hasher, which owns all state and produces the digest.
#[derive(Clone, Copy, Debug)]
pub struct SerializingStatefulSponge<Inner> {
    /// The hasher that absorbs the serialized words and produces the output.
    inner: Inner,
}
25
26impl<Inner> SerializingStatefulSponge<Inner> {
27 pub const fn new(inner: Inner) -> Self {
28 Self { inner }
29 }
30}
31
32impl<F, Inner, const OUT: usize> StatefulHasher<F, [u8; OUT]> for SerializingStatefulSponge<Inner>
39where
40 F: Field,
41 Inner: StatefulHasher<u8, [u8; OUT]>,
42{
43 type State = Inner::State;
44
45 fn absorb_into(&self, state: &mut Self::State, input: impl IntoIterator<Item = F>) {
46 self.inner.absorb_into(state, F::into_byte_stream(input));
47 }
48
49 fn squeeze(&self, state: &Self::State) -> [u8; OUT] {
50 self.inner.squeeze(state)
51 }
52}
53
54impl<F, Inner, const OUT: usize> StatefulHasher<F, [u32; OUT]> for SerializingStatefulSponge<Inner>
56where
57 F: Field,
58 Inner: StatefulHasher<u32, [u32; OUT]>,
59{
60 type State = Inner::State;
61
62 fn absorb_into(&self, state: &mut Self::State, input: impl IntoIterator<Item = F>) {
63 self.inner.absorb_into(state, F::into_u32_stream(input));
64 }
65
66 fn squeeze(&self, state: &Self::State) -> [u32; OUT] {
67 self.inner.squeeze(state)
68 }
69}
70
71impl<F, Inner, const OUT: usize> StatefulHasher<F, [u64; OUT]> for SerializingStatefulSponge<Inner>
73where
74 F: Field,
75 Inner: StatefulHasher<u64, [u64; OUT]>,
76{
77 type State = Inner::State;
78
79 fn absorb_into(&self, state: &mut Self::State, input: impl IntoIterator<Item = F>) {
80 self.inner.absorb_into(state, F::into_u64_stream(input));
81 }
82
83 fn squeeze(&self, state: &Self::State) -> [u64; OUT] {
84 self.inner.squeeze(state)
85 }
86}
87
88impl<F, Inner, const M: usize, const OUT: usize> StatefulHasher<[F; M], [[u8; M]; OUT]>
96 for SerializingStatefulSponge<Inner>
97where
98 F: Field,
99 Inner: StatefulHasher<[u8; M], [[u8; M]; OUT]>,
100{
101 type State = Inner::State;
102
103 fn absorb_into(&self, state: &mut Self::State, input: impl IntoIterator<Item = [F; M]>) {
104 self.inner.absorb_into(state, F::into_parallel_byte_streams(input));
105 }
106
107 fn squeeze(&self, state: &Self::State) -> [[u8; M]; OUT] {
108 self.inner.squeeze(state)
109 }
110}
111
112impl<F, Inner, const M: usize, const OUT: usize> StatefulHasher<[F; M], [[u32; M]; OUT]>
114 for SerializingStatefulSponge<Inner>
115where
116 F: Field,
117 Inner: StatefulHasher<[u32; M], [[u32; M]; OUT]>,
118{
119 type State = Inner::State;
120
121 fn absorb_into(&self, state: &mut Self::State, input: impl IntoIterator<Item = [F; M]>) {
122 self.inner.absorb_into(state, F::into_parallel_u32_streams(input));
123 }
124
125 fn squeeze(&self, state: &Self::State) -> [[u32; M]; OUT] {
126 self.inner.squeeze(state)
127 }
128}
129
130impl<F, Inner, const M: usize, const OUT: usize> StatefulHasher<[F; M], [[u64; M]; OUT]>
132 for SerializingStatefulSponge<Inner>
133where
134 F: Field,
135 Inner: StatefulHasher<[u64; M], [[u64; M]; OUT]>,
136{
137 type State = Inner::State;
138
139 fn absorb_into(&self, state: &mut Self::State, input: impl IntoIterator<Item = [F; M]>) {
140 self.inner.absorb_into(state, F::into_parallel_u64_streams(input));
141 }
142
143 fn squeeze(&self, state: &Self::State) -> [[u64; M]; OUT] {
144 self.inner.squeeze(state)
145 }
146}
147
/// Converts an inner hasher's alignment into an alignment counted in field elements.
///
/// The inner alignment is `inner_alignment` words of `item_bytes` bytes each, i.e. a block
/// of `inner_alignment * item_bytes` bytes. Field elements are `field_bytes` bytes each.
/// For nonzero arguments the result is the smallest positive `k` such that
/// `k * field_bytes` is a multiple of that block size, i.e. `lcm(field_bytes, block) /
/// field_bytes`. It is computed as `block / gcd(field_bytes, block)`.
///
/// Edge cases:
/// - Returns `0` when the block size is `0`.
/// - Divides by zero when both `field_bytes` and the block size are `0`.
/// - Overflow of the block-size multiplication follows Rust's normal arithmetic rules.
const fn compute_field_alignment(
    field_bytes: usize,
    item_bytes: usize,
    inner_alignment: usize,
) -> usize {
    /// Euclid's algorithm; `gcd(x, 0) == x`.
    const fn gcd(mut x: usize, mut y: usize) -> usize {
        while y != 0 {
            let rem = x % y;
            x = y;
            y = rem;
        }
        x
    }

    let block_bytes = item_bytes * inner_alignment;
    block_bytes / gcd(field_bytes, block_bytes)
}
192
193impl<F, Inner, T> Alignable<F, T> for SerializingStatefulSponge<Inner>
194where
195 F: Field,
196 Inner: Alignable<T, T>,
197{
198 const ALIGNMENT: usize =
199 compute_field_alignment(F::NUM_BYTES, size_of::<T>(), Inner::ALIGNMENT);
200}
201
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use p3_bn254::Bn254;
    use p3_field::Field;
    use p3_goldilocks::Goldilocks;
    use p3_mersenne_31::Mersenne31;

    use super::*;
    use crate::{StatefulSponge, testing::MockBinaryPermutation};

    /// Checks the contract of `ALIGNMENT` for field `F`.
    ///
    /// For every input length from 1 up to `3 * ALIGNMENT`, absorbing the input must give
    /// the same digest as absorbing it zero-padded up to the next multiple of `ALIGNMENT`.
    fn test_alignment_semantic<F: Field, const WIDTH: usize, const RATE: usize, const OUT: usize>()
    where
        SerializingStatefulSponge<
            StatefulSponge<MockBinaryPermutation<u64, WIDTH>, WIDTH, RATE, OUT>,
        >: Alignable<F, u64>,
        [u64; WIDTH]: Default,
    {
        let sponge = StatefulSponge::<_, WIDTH, RATE, OUT>::new(
            MockBinaryPermutation::<u64, WIDTH>::default(),
        );
        let hasher = SerializingStatefulSponge::new(sponge);
        let alignment = <SerializingStatefulSponge<
            StatefulSponge<MockBinaryPermutation<u64, WIDTH>, WIDTH, RATE, OUT>,
        > as Alignable<F, u64>>::ALIGNMENT;

        // Absorbs `elems` into a fresh all-zero state and squeezes the digest.
        let digest = |elems: &[F]| -> [u64; OUT] {
            let mut state = [0u64; WIDTH];
            StatefulHasher::<F, [u64; OUT]>::absorb_into(
                &hasher,
                &mut state,
                elems.iter().copied(),
            );
            StatefulHasher::<F, [u64; OUT]>::squeeze(&hasher, &state)
        };

        for input_len in 1..=(alignment * 3) {
            // Sequential values 1..=input_len.
            let input: Vec<F> = (1..=input_len).map(F::from_usize).collect();

            // Zero-pad up to the next multiple of the alignment (no-op if already aligned).
            let mut padded = input.clone();
            padded.resize(input_len.next_multiple_of(alignment), F::ZERO);

            assert_eq!(digest(&input), digest(&padded));
        }
    }

    #[test]
    fn alignment_semantic() {
        test_alignment_semantic::<Mersenne31, 16, 8, 4>();
        test_alignment_semantic::<Goldilocks, 16, 8, 4>();
        test_alignment_semantic::<Bn254, 16, 8, 4>();
    }

    #[test]
    fn test_compute_field_alignment() {
        // 4-byte elements, 4-byte words, 8-word block: 32 / gcd(4, 32) = 8 elements.
        assert_eq!(compute_field_alignment(4, 4, 8), 8);

        // 4-byte elements, 8-byte words, 4-word block: 32 / gcd(4, 32) = 8 elements.
        assert_eq!(compute_field_alignment(4, 8, 4), 8);

        // 8-byte elements, 4-byte words, 8-word block: 32 / gcd(8, 32) = 4 elements.
        assert_eq!(compute_field_alignment(8, 4, 8), 4);

        // 8-byte elements, 8-byte words, 4-word block: 32 / gcd(8, 32) = 4 elements.
        assert_eq!(compute_field_alignment(8, 8, 4), 4);

        // 32-byte elements vs a 16-byte block: one element already spans two whole blocks.
        assert_eq!(compute_field_alignment(32, 8, 2), 1);
    }

    #[test]
    fn test_compute_field_alignment_non_power_of_2_rate() {
        // 24-byte block: 24 / gcd(4, 24) = 6 elements.
        assert_eq!(compute_field_alignment(4, 8, 3), 6);

        // 12-byte block: 12 / gcd(8, 12) = 3 elements (24 bytes, i.e. two blocks).
        assert_eq!(compute_field_alignment(8, 4, 3), 3);

        // 28-byte block: 28 / gcd(4, 28) = 7 elements.
        assert_eq!(compute_field_alignment(4, 4, 7), 7);

        // 40-byte block: 40 / gcd(8, 40) = 5 elements.
        assert_eq!(compute_field_alignment(8, 8, 5), 5);
    }
}