Skip to main content

miden_stateful_hasher/
serializing_sponge.rs

1//! Serializing stateful sponge.
2//!
3//! This module provides [`SerializingStatefulSponge`] which serializes field elements
4//! to binary before absorption into an inner `StatefulHasher`.
5
6use core::mem::size_of;
7
8use p3_field::Field;
9
10use crate::{Alignable, StatefulHasher};
11
/// Adapter that converts field elements into a binary stream (bytes, `u32` words or `u64`
/// words) and forwards that stream to an inner `StatefulHasher`.
///
/// The field-to-binary conversions are the same ones `SerializingHasher` performs; the
/// difference is that this type exposes the `StatefulHasher` interface and leaves all state
/// handling to an inner stateful hasher that consumes binary data.
///
/// In contrast to `ChainingHasher` (chaining mode `H(state || input)`), absorption here goes
/// straight through the inner hasher's `absorb_into`, so the inner sponge's absorption
/// semantics are kept intact.
#[derive(Copy, Clone, Debug)]
pub struct SerializingStatefulSponge<Inner> {
    /// Binary-input hasher that receives the serialized stream.
    inner: Inner,
}

impl<Inner> SerializingStatefulSponge<Inner> {
    /// Wraps `inner` so that it can be driven with field elements instead of raw binary data.
    pub const fn new(inner: Inner) -> Self {
        SerializingStatefulSponge { inner }
    }
}
31
32// -----------------------------------------------------------------------------
33// Scalar implementations: F -> [B; OUT]
34// The digest type [B; OUT] distinguishes these from parallel implementations.
35// -----------------------------------------------------------------------------
36
37// Scalar field -> u8 based inner
38impl<F, Inner, const OUT: usize> StatefulHasher<F, [u8; OUT]> for SerializingStatefulSponge<Inner>
39where
40    F: Field,
41    Inner: StatefulHasher<u8, [u8; OUT]>,
42{
43    type State = Inner::State;
44
45    fn absorb_into(&self, state: &mut Self::State, input: impl IntoIterator<Item = F>) {
46        self.inner.absorb_into(state, F::into_byte_stream(input));
47    }
48
49    fn squeeze(&self, state: &Self::State) -> [u8; OUT] {
50        self.inner.squeeze(state)
51    }
52}
53
54// Scalar field -> u32 based inner
55impl<F, Inner, const OUT: usize> StatefulHasher<F, [u32; OUT]> for SerializingStatefulSponge<Inner>
56where
57    F: Field,
58    Inner: StatefulHasher<u32, [u32; OUT]>,
59{
60    type State = Inner::State;
61
62    fn absorb_into(&self, state: &mut Self::State, input: impl IntoIterator<Item = F>) {
63        self.inner.absorb_into(state, F::into_u32_stream(input));
64    }
65
66    fn squeeze(&self, state: &Self::State) -> [u32; OUT] {
67        self.inner.squeeze(state)
68    }
69}
70
71// Scalar field -> u64 based inner
72impl<F, Inner, const OUT: usize> StatefulHasher<F, [u64; OUT]> for SerializingStatefulSponge<Inner>
73where
74    F: Field,
75    Inner: StatefulHasher<u64, [u64; OUT]>,
76{
77    type State = Inner::State;
78
79    fn absorb_into(&self, state: &mut Self::State, input: impl IntoIterator<Item = F>) {
80        self.inner.absorb_into(state, F::into_u64_stream(input));
81    }
82
83    fn squeeze(&self, state: &Self::State) -> [u64; OUT] {
84        self.inner.squeeze(state)
85    }
86}
87
88// -----------------------------------------------------------------------------
89// Parallel implementations: [F; M] -> [[B; M]; OUT]
90// The digest type [[B; M]; OUT] is structurally different from [B; OUT],
91// which prevents coherence conflicts with scalar implementations.
92// -----------------------------------------------------------------------------
93
94// Parallel [F; M] -> [u8; M] based inner
95impl<F, Inner, const M: usize, const OUT: usize> StatefulHasher<[F; M], [[u8; M]; OUT]>
96    for SerializingStatefulSponge<Inner>
97where
98    F: Field,
99    Inner: StatefulHasher<[u8; M], [[u8; M]; OUT]>,
100{
101    type State = Inner::State;
102
103    fn absorb_into(&self, state: &mut Self::State, input: impl IntoIterator<Item = [F; M]>) {
104        self.inner.absorb_into(state, F::into_parallel_byte_streams(input));
105    }
106
107    fn squeeze(&self, state: &Self::State) -> [[u8; M]; OUT] {
108        self.inner.squeeze(state)
109    }
110}
111
112// Parallel [F; M] -> [u32; M] based inner
113impl<F, Inner, const M: usize, const OUT: usize> StatefulHasher<[F; M], [[u32; M]; OUT]>
114    for SerializingStatefulSponge<Inner>
115where
116    F: Field,
117    Inner: StatefulHasher<[u32; M], [[u32; M]; OUT]>,
118{
119    type State = Inner::State;
120
121    fn absorb_into(&self, state: &mut Self::State, input: impl IntoIterator<Item = [F; M]>) {
122        self.inner.absorb_into(state, F::into_parallel_u32_streams(input));
123    }
124
125    fn squeeze(&self, state: &Self::State) -> [[u32; M]; OUT] {
126        self.inner.squeeze(state)
127    }
128}
129
130// Parallel [F; M] -> [u64; M] based inner
131impl<F, Inner, const M: usize, const OUT: usize> StatefulHasher<[F; M], [[u64; M]; OUT]>
132    for SerializingStatefulSponge<Inner>
133where
134    F: Field,
135    Inner: StatefulHasher<[u64; M], [[u64; M]; OUT]>,
136{
137    type State = Inner::State;
138
139    fn absorb_into(&self, state: &mut Self::State, input: impl IntoIterator<Item = [F; M]>) {
140        self.inner.absorb_into(state, F::into_parallel_u64_streams(input));
141    }
142
143    fn squeeze(&self, state: &Self::State) -> [[u64; M]; OUT] {
144        self.inner.squeeze(state)
145    }
146}
147
148// -----------------------------------------------------------------------------
149// Alignable implementations for SerializingStatefulSponge
150// -----------------------------------------------------------------------------
151
/// Translates an inner hasher's alignment (in binary items) into an alignment in field elements.
///
/// Parameters:
/// - `field_bytes`: serialized size of one field element (`F::NUM_BYTES`).
/// - `item_bytes`: size of one inner item (1 for `u8`, 4 for `u32`, 8 for `u64`).
/// - `inner_alignment`: the inner hasher's alignment, counted in inner items (e.g. its rate).
///
/// Returns `lcm(field_bytes, B) / field_bytes`, where `B = inner_alignment * item_bytes` is the
/// inner block size in bytes. When `B > 0`, this is the smallest positive number of field
/// elements whose serialization is a whole multiple of `B` bytes. In other words, serializing
/// that many elements yields `inner_alignment` inner items or a multiple of it.
///
/// Worked examples:
/// - 4-byte field, rate 3 `u64`s (`B = 24`): `lcm(4, 24) / 4 = 6` field elements.
/// - 32-byte field, rate 2 `u64`s (`B = 16`): `lcm(32, 16) / 32 = 1`. Whenever `field_bytes`
///   is a multiple of `B`, every field element ends on a block boundary, so the result is 1.
///
/// Degenerate inputs: returns 0 when `B == 0`, and divides by zero (panicking, or failing
/// const evaluation) when `field_bytes` and `B` are both 0.
const fn compute_field_alignment(
    field_bytes: usize,
    item_bytes: usize,
    inner_alignment: usize,
) -> usize {
    /// Euclid's algorithm in recursive form; `gcd(x, 0) == x`.
    const fn gcd(x: usize, y: usize) -> usize {
        if y == 0 { x } else { gcd(y, x % y) }
    }

    // lcm(f, B) / f == (f * B / gcd(f, B)) / f == B / gcd(f, B), which avoids the product.
    let block_bytes = item_bytes * inner_alignment;
    block_bytes / gcd(field_bytes, block_bytes)
}
192
193impl<F, Inner, T> Alignable<F, T> for SerializingStatefulSponge<Inner>
194where
195    F: Field,
196    Inner: Alignable<T, T>,
197{
198    const ALIGNMENT: usize =
199        compute_field_alignment(F::NUM_BYTES, size_of::<T>(), Inner::ALIGNMENT);
200}
201
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use p3_bn254::Bn254;
    use p3_field::Field;
    use p3_goldilocks::Goldilocks;
    use p3_mersenne_31::Mersenne31;

    use super::*;
    use crate::{StatefulSponge, testing::MockBinaryPermutation};

    /// Verifies that implicit zero-padding equals explicit zeros for serialized fields.
    ///
    /// For every input length in `1..=3 * ALIGNMENT`, this absorbs the raw input and the input
    /// extended with `F::ZERO` up to the next multiple of `ALIGNMENT`. It then asserts that
    /// both produce the same digest.
    fn test_alignment_semantic<F: Field, const WIDTH: usize, const RATE: usize, const OUT: usize>()
    where
        SerializingStatefulSponge<
            StatefulSponge<MockBinaryPermutation<u64, WIDTH>, WIDTH, RATE, OUT>,
        >: Alignable<F, u64>,
        [u64; WIDTH]: Default,
    {
        let inner = StatefulSponge::<_, WIDTH, RATE, OUT>::new(
            MockBinaryPermutation::<u64, WIDTH>::default(),
        );
        let hasher = SerializingStatefulSponge::new(inner);

        let alignment = <SerializingStatefulSponge<
            StatefulSponge<MockBinaryPermutation<u64, WIDTH>, WIDTH, RATE, OUT>,
        > as Alignable<F, u64>>::ALIGNMENT;
        // A zero alignment would make the padding computation below meaningless.
        assert!(alignment > 0, "alignment must be positive");

        // Absorb `elems` into a fresh all-zero state and squeeze the digest.
        let digest = |elems: &[F]| -> [u64; OUT] {
            let mut state = [0u64; WIDTH];
            StatefulHasher::<F, [u64; OUT]>::absorb_into(
                &hasher,
                &mut state,
                elems.iter().copied(),
            );
            StatefulHasher::<F, [u64; OUT]>::squeeze(&hasher, &state)
        };

        for input_len in 1..=(alignment * 3) {
            let input: Vec<F> = (1..=input_len).map(F::from_usize).collect();
            let output_unpadded = digest(&input);

            // Pad with explicit zeros up to the next multiple of `alignment`.
            let zeros_needed = input_len.next_multiple_of(alignment) - input_len;
            let mut padded_input = input;
            padded_input.extend(core::iter::repeat_n(F::ZERO, zeros_needed));
            let output_padded = digest(&padded_input);

            assert_eq!(
                output_unpadded, output_padded,
                "implicit padding differs from explicit zeros at input_len = {input_len}"
            );
        }
    }

    #[test]
    fn alignment_semantic() {
        // Different field sizes exercise different alignment calculations
        test_alignment_semantic::<Mersenne31, 16, 8, 4>(); // 4 bytes -> alignment 16
        test_alignment_semantic::<Goldilocks, 16, 8, 4>(); // 8 bytes -> alignment 8
        test_alignment_semantic::<Bn254, 16, 8, 4>(); // 32 bytes -> alignment 2
    }

    #[test]
    fn test_compute_field_alignment() {
        // 4-byte field (e.g., Mersenne31) to u32 (4 bytes), inner alignment 8
        // inner_bytes = 32, gcd(4, 32) = 4, alignment = 32/4 = 8
        assert_eq!(compute_field_alignment(4, 4, 8), 8);

        // 4-byte field to u64 (8 bytes), inner alignment 4
        // inner_bytes = 32, gcd(4, 32) = 4, alignment = 32/4 = 8
        assert_eq!(compute_field_alignment(4, 8, 4), 8);

        // 8-byte field (e.g., Goldilocks) to u32 (4 bytes), inner alignment 8
        // inner_bytes = 32, gcd(8, 32) = 8, alignment = 32/8 = 4
        assert_eq!(compute_field_alignment(8, 4, 8), 4);

        // 8-byte field to u64 (8 bytes), inner alignment 4
        // inner_bytes = 32, gcd(8, 32) = 8, alignment = 32/8 = 4
        assert_eq!(compute_field_alignment(8, 8, 4), 4);

        // 32-byte field (e.g., Bn254) to u64 (8 bytes), inner alignment 2
        // inner_bytes = 16, gcd(32, 16) = 16, alignment = 16/16 = 1
        assert_eq!(compute_field_alignment(32, 8, 2), 1);
    }

    #[test]
    fn test_compute_field_alignment_non_power_of_2_rate() {
        // 4-byte field (e.g., Mersenne31) to u64 (8 bytes), rate 3
        // inner_bytes = 24, gcd(4, 24) = 4, alignment = 24/4 = 6
        // Verify: 6 fields * 4 bytes = 24 bytes = 3 u64s ✓
        assert_eq!(compute_field_alignment(4, 8, 3), 6);

        // 8-byte field (e.g., Goldilocks) to u32 (4 bytes), rate 3
        // inner_bytes = 12, gcd(8, 12) = 4, alignment = 12/4 = 3
        // Verify: 3 fields * 8 bytes = 24 bytes = 6 u32s = 2 * rate ✓
        assert_eq!(compute_field_alignment(8, 4, 3), 3);

        // 4-byte field to u32 (4 bytes), rate 7
        // inner_bytes = 28, gcd(4, 28) = 4, alignment = 28/4 = 7
        assert_eq!(compute_field_alignment(4, 4, 7), 7);

        // 8-byte field to u64 (8 bytes), rate 5
        // inner_bytes = 40, gcd(8, 40) = 8, alignment = 40/8 = 5
        assert_eq!(compute_field_alignment(8, 8, 5), 5);
    }

    #[test]
    fn compute_field_alignment_is_minimal() {
        // Cross-check against a brute-force search. The result must be the smallest positive
        // number of field elements whose serialized size is a whole multiple of the inner
        // block size.
        for field_bytes in [1usize, 2, 3, 4, 6, 8, 16, 32] {
            for item_bytes in [1usize, 4, 8] {
                for inner_alignment in 1usize..=16 {
                    let block_bytes = item_bytes * inner_alignment;
                    let expected = (1usize..)
                        .find(|k| (k * field_bytes) % block_bytes == 0)
                        .expect("k = block_bytes always satisfies the divisibility condition");
                    assert_eq!(
                        compute_field_alignment(field_bytes, item_bytes, inner_alignment),
                        expected,
                        "field_bytes = {field_bytes}, item_bytes = {item_bytes}, \
                         inner_alignment = {inner_alignment}"
                    );
                }
            }
        }
    }
}