Skip to main content

cyber_hemera/
sponge.rs

1use std::fmt;
2use std::io;
3
4use p3_goldilocks::Goldilocks;
5
6use crate::encoding::{bytes_to_rate_block, hash_to_bytes};
7use crate::params::{self, OUTPUT_BYTES, OUTPUT_BYTES_PER_ELEMENT, OUTPUT_ELEMENTS, RATE, RATE_BYTES, WIDTH};
8
// Domain-separation tags. Each constructor writes exactly one of these into
// `state[CAPACITY_START + 3]` (i.e. `state[11]` with an 8-element rate), so the
// distinct tags give every hashing mode a distinct starting state.

/// Tag for plain hashing (`Hasher::new`).
const DOMAIN_HASH: u64 = 0;
/// Tag for keyed hashing (`Hasher::new_keyed`).
const DOMAIN_KEYED: u64 = 1;
/// Tag for the context phase of key derivation.
const DOMAIN_DERIVE_KEY_CONTEXT: u64 = 2;
/// Tag for the key-material phase of key derivation.
const DOMAIN_DERIVE_KEY_MATERIAL: u64 = 3;
14
/// Index of the first capacity element in the sponge state.
///
/// The state is laid out rate-first: `state[..RATE]` receives absorbed input,
/// and the capacity starts right after it. Capacity slots used in this file:
/// `CAPACITY_START + 2` (total input length, written at finalization) and
/// `CAPACITY_START + 3` (domain-separation tag).
const CAPACITY_START: usize = RATE; // 8
17
18/// A 64-byte Poseidon2 hash output (Hemera: 8 Goldilocks elements).
19#[derive(Clone, Copy, PartialEq, Eq, Hash)]
20pub struct Hash([u8; OUTPUT_BYTES]);
21
22impl Hash {
23    /// Create a hash from a raw byte array.
24    pub const fn from_bytes(bytes: [u8; OUTPUT_BYTES]) -> Self {
25        Self(bytes)
26    }
27
28    /// Return the hash as a byte slice.
29    pub fn as_bytes(&self) -> &[u8; OUTPUT_BYTES] {
30        &self.0
31    }
32
33    /// Convert the hash to a hex string.
34    pub fn to_hex(&self) -> String {
35        let mut s = String::with_capacity(OUTPUT_BYTES * 2);
36        for byte in &self.0 {
37            use fmt::Write;
38            write!(s, "{byte:02x}").unwrap();
39        }
40        s
41    }
42}
43
44impl From<[u8; OUTPUT_BYTES]> for Hash {
45    fn from(bytes: [u8; OUTPUT_BYTES]) -> Self {
46        Self(bytes)
47    }
48}
49
#[cfg(feature = "serde")]
impl serde::Serialize for Hash {
    /// Encode the digest as a fixed-length tuple of `OUTPUT_BYTES` `u8` elements
    /// (a tuple, not a serde byte string).
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeTuple;

        let mut tuple = serializer.serialize_tuple(OUTPUT_BYTES)?;
        self.0
            .iter()
            .try_for_each(|byte| tuple.serialize_element(byte))?;
        tuple.end()
    }
}
61
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for Hash {
    /// Decode a digest from a tuple of exactly `OUTPUT_BYTES` `u8` elements.
    ///
    /// A sequence that ends early is rejected with an `invalid_length` error
    /// reporting how many bytes were read.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use serde::de::{SeqAccess, Visitor};

        /// Collects the digest one sequence element at a time.
        struct DigestVisitor;

        impl<'de> Visitor<'de> for DigestVisitor {
            type Value = Hash;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                write!(formatter, "a byte array of length {OUTPUT_BYTES}")
            }

            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Hash, A::Error> {
                let mut digest = [0u8; OUTPUT_BYTES];
                for (index, slot) in digest.iter_mut().enumerate() {
                    match seq.next_element()? {
                        Some(byte) => *slot = byte,
                        None => return Err(serde::de::Error::invalid_length(index, &self)),
                    }
                }
                Ok(Hash(digest))
            }
        }

        deserializer.deserialize_tuple(OUTPUT_BYTES, DigestVisitor)
    }
}
89
90impl AsRef<[u8]> for Hash {
91    fn as_ref(&self) -> &[u8] {
92        &self.0
93    }
94}
95
96impl fmt::Display for Hash {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        for byte in &self.0 {
99            write!(f, "{byte:02x}")?;
100        }
101        Ok(())
102    }
103}
104
105impl fmt::Debug for Hash {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        write!(f, "Hash({self})")
108    }
109}
110
111/// A streaming Poseidon2 hasher.
112///
113/// Supports three modes via domain separation:
114/// - Plain hash (`new`)
115/// - Keyed hash (`new_keyed`)
116/// - Key derivation (`new_derive_key`)
117///
118/// Data is absorbed in 56-byte blocks (8 Goldilocks elements × 7 bytes each).
119#[derive(Clone)]
120pub struct Hasher {
121    state: [Goldilocks; WIDTH],
122    buf: Vec<u8>,
123    absorbed: u64,
124}
125
126impl Hasher {
127    /// Create a new hasher in plain hash mode.
128    pub fn new() -> Self {
129        let mut state = [Goldilocks::new(0); WIDTH];
130        state[CAPACITY_START + 3] = Goldilocks::new(DOMAIN_HASH);
131        Self {
132            state,
133            buf: Vec::new(),
134            absorbed: 0,
135        }
136    }
137
138    /// Create a new hasher in keyed hash mode.
139    ///
140    /// The key is absorbed as the first block (before any user data).
141    pub fn new_keyed(key: &[u8; OUTPUT_BYTES]) -> Self {
142        let mut state = [Goldilocks::new(0); WIDTH];
143        state[CAPACITY_START + 3] = Goldilocks::new(DOMAIN_KEYED);
144
145        // Absorb the key into the rate portion via the normal buffer path.
146        let mut hasher = Self {
147            state,
148            buf: Vec::new(),
149            absorbed: 0,
150        };
151        hasher.update(key.as_slice());
152        hasher
153    }
154
155    /// Create a new hasher in derive-key mode.
156    ///
157    /// First hashes the context string to produce a context key, then
158    /// sets up a second hasher seeded with that key for absorbing key material.
159    pub(crate) fn new_derive_key_context(context: &str) -> Self {
160        let mut state = [Goldilocks::new(0); WIDTH];
161        state[CAPACITY_START + 3] = Goldilocks::new(DOMAIN_DERIVE_KEY_CONTEXT);
162        let mut hasher = Self {
163            state,
164            buf: Vec::new(),
165            absorbed: 0,
166        };
167        hasher.update(context.as_bytes());
168        hasher
169    }
170
171    /// Create a derive-key hasher for the material phase, seeded by a context hash.
172    pub(crate) fn new_derive_key_material(context_hash: &Hash) -> Self {
173        let mut state = [Goldilocks::new(0); WIDTH];
174        state[CAPACITY_START + 3] = Goldilocks::new(DOMAIN_DERIVE_KEY_MATERIAL);
175
176        // Seed the rate portion with the context hash (4 elements = 32 bytes).
177        for (i, chunk) in context_hash.0.chunks(OUTPUT_BYTES_PER_ELEMENT).enumerate() {
178            let val = u64::from_le_bytes(chunk.try_into().unwrap());
179            state[i] = Goldilocks::new(val);
180        }
181        params::permute(&mut state);
182
183        Self {
184            state,
185            buf: Vec::new(),
186            absorbed: 0,
187        }
188    }
189
190    /// Absorb input data into the sponge.
191    pub fn update(&mut self, data: &[u8]) -> &mut Self {
192        self.buf.extend_from_slice(data);
193        self.absorbed += data.len() as u64;
194
195        // Process complete rate blocks.
196        while self.buf.len() >= RATE_BYTES {
197            let block_bytes: Vec<u8> = self.buf.drain(..RATE_BYTES).collect();
198            let mut rate_block = [Goldilocks::new(0); RATE];
199            bytes_to_rate_block(&block_bytes, &mut rate_block);
200            self.absorb_block(&rate_block);
201        }
202
203        self
204    }
205
206    /// Add a rate block into the state (Goldilocks field addition) and permute.
207    fn absorb_block(&mut self, block: &[Goldilocks; RATE]) {
208        for (i, block_elem) in block.iter().enumerate() {
209            self.state[i] = self.state[i] + *block_elem;
210        }
211        params::permute(&mut self.state);
212    }
213
214    /// Apply padding and produce the finalized state.
215    ///
216    /// Padding scheme (Hemera: 0x01 || 0x00*):
217    /// 1. Append 0x01 byte to remaining buffer
218    /// 2. Pad to RATE_BYTES with zeros
219    /// 3. Encode as field elements and absorb
220    /// 4. Store total byte count in capacity[2]
221    fn finalize_state(&self) -> [Goldilocks; WIDTH] {
222        let mut state = self.state;
223        let mut padded = self.buf.clone();
224
225        // Append padding marker (Hemera: 0x01).
226        padded.push(0x01);
227
228        // Pad to full rate block.
229        padded.resize(RATE_BYTES, 0x00);
230
231        // Encode and absorb the final block (Goldilocks field addition).
232        let mut rate_block = [Goldilocks::new(0); RATE];
233        bytes_to_rate_block(&padded, &mut rate_block);
234        for i in 0..RATE {
235            state[i] = state[i] + rate_block[i];
236        }
237
238        // Encode total length in capacity.
239        state[CAPACITY_START + 2] = Goldilocks::new(self.absorbed);
240
241        params::permute(&mut state);
242        state
243    }
244
245    /// Finalize and return the hash.
246    pub fn finalize(&self) -> Hash {
247        let state = self.finalize_state();
248        let output: [Goldilocks; OUTPUT_ELEMENTS] = state[..OUTPUT_ELEMENTS]
249            .try_into()
250            .unwrap();
251        Hash(hash_to_bytes(&output))
252    }
253
254    /// Finalize and return an extendable output reader (XOF mode).
255    pub fn finalize_xof(&self) -> OutputReader {
256        let state = self.finalize_state();
257        OutputReader {
258            state,
259            buffer: [0u8; OUTPUT_BYTES],
260            buffer_pos: OUTPUT_BYTES, // empty — will squeeze on first read
261        }
262    }
263}
264
265impl Default for Hasher {
266    fn default() -> Self {
267        Self::new()
268    }
269}
270
271impl fmt::Debug for Hasher {
272    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273        f.debug_struct("Hasher")
274            .field("absorbed", &self.absorbed)
275            .field("buffered", &self.buf.len())
276            .finish()
277    }
278}
279
280/// An extendable-output reader that can produce arbitrary-length output.
281///
282/// Operates by repeatedly squeezing OUTPUT_BYTES from the sponge state,
283/// then permuting to produce more output.
284pub struct OutputReader {
285    state: [Goldilocks; WIDTH],
286    buffer: [u8; OUTPUT_BYTES],
287    buffer_pos: usize,
288}
289
290impl OutputReader {
291    /// Fill the provided buffer with hash output bytes.
292    pub fn fill(&mut self, output: &mut [u8]) {
293        let mut written = 0;
294        while written < output.len() {
295            if self.buffer_pos >= OUTPUT_BYTES {
296                self.squeeze();
297            }
298            let available = OUTPUT_BYTES - self.buffer_pos;
299            let needed = output.len() - written;
300            let n = available.min(needed);
301            output[written..written + n]
302                .copy_from_slice(&self.buffer[self.buffer_pos..self.buffer_pos + n]);
303            self.buffer_pos += n;
304            written += n;
305        }
306    }
307
308    /// Squeeze one block of output from the sponge.
309    fn squeeze(&mut self) {
310        let output_elems: [Goldilocks; OUTPUT_ELEMENTS] = self.state[..OUTPUT_ELEMENTS]
311            .try_into()
312            .unwrap();
313        self.buffer = hash_to_bytes(&output_elems);
314        self.buffer_pos = 0;
315        params::permute(&mut self.state);
316    }
317}
318
319impl io::Read for OutputReader {
320    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
321        self.fill(buf);
322        Ok(buf.len())
323    }
324}
325
326impl fmt::Debug for OutputReader {
327    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328        f.debug_struct("OutputReader").finish_non_exhaustive()
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Hash `data` in plain mode, feeding it to `update` in pieces of at most `piece` bytes.
    fn hash_in_pieces(data: &[u8], piece: usize) -> Hash {
        let mut hasher = Hasher::new();
        for chunk in data.chunks(piece) {
            hasher.update(chunk);
        }
        hasher.finalize()
    }

    #[test]
    fn hash_display_is_hex() {
        let h = Hash([0xAB; OUTPUT_BYTES]);
        let s = format!("{h}");
        assert_eq!(s.len(), OUTPUT_BYTES * 2);
        assert!(s.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn to_hex_matches_display_and_debug_wraps_it() {
        let h = Hasher::new().update(b"hex").finalize();
        assert_eq!(h.to_hex(), h.to_string());
        assert_eq!(format!("{h:?}"), format!("Hash({h})"));
    }

    #[test]
    fn byte_conversions_round_trip() {
        let bytes = [0x5Au8; OUTPUT_BYTES];
        let h = Hash::from_bytes(bytes);
        assert_eq!(h.as_bytes(), &bytes);
        assert_eq!(AsRef::<[u8]>::as_ref(&h), &bytes[..]);
        assert_eq!(Hash::from(bytes), h);
    }

    #[test]
    fn empty_hash_is_not_zero() {
        let h = Hasher::new().finalize();
        assert_ne!(h.0, [0u8; OUTPUT_BYTES]);
    }

    #[test]
    fn different_inputs_different_hashes() {
        let h1 = Hasher::new().update(b"a").finalize();
        let h2 = Hasher::new().update(b"b").finalize();
        assert_ne!(h1, h2);
    }

    #[test]
    fn streaming_consistency() {
        let data = b"hello world, this is a test of streaming consistency!";
        let one_shot = {
            let mut h = Hasher::new();
            h.update(data);
            h.finalize()
        };
        let streamed = {
            let mut h = Hasher::new();
            h.update(&data[..5]);
            h.update(&data[5..20]);
            h.update(&data[20..]);
            h.finalize()
        };
        assert_eq!(one_shot, streamed);
    }

    #[test]
    fn streaming_across_rate_boundary() {
        // RATE_BYTES (56) bytes make one rate block, so 100 bytes crosses a boundary.
        let data = vec![0x42u8; 100];
        let one_shot = {
            let mut h = Hasher::new();
            h.update(&data);
            h.finalize()
        };
        let byte_at_a_time = {
            let mut h = Hasher::new();
            for b in &data {
                h.update(std::slice::from_ref(b));
            }
            h.finalize()
        };
        assert_eq!(one_shot, byte_at_a_time);
    }

    #[test]
    fn streaming_matches_one_shot_for_awkward_piece_sizes() {
        // Several full rate blocks plus a partial tail, split so that buffered
        // bytes must be topped up across calls.
        let data: Vec<u8> = (0..RATE_BYTES * 5 + 13).map(|i| (i % 251) as u8).collect();
        let one_shot = Hasher::new().update(&data).finalize();
        for piece in [1, 7, RATE_BYTES - 1, RATE_BYTES, RATE_BYTES + 1, 3 * RATE_BYTES] {
            assert_eq!(hash_in_pieces(&data, piece), one_shot, "piece size {piece}");
        }
    }

    #[test]
    fn input_filling_whole_blocks_is_still_padded() {
        // Input that exactly fills rate blocks must still get a padding block,
        // so it cannot collide with the same data minus its last byte.
        let data = vec![0x11u8; RATE_BYTES * 2];
        let full = Hasher::new().update(&data).finalize();
        let short = Hasher::new().update(&data[..data.len() - 1]).finalize();
        assert_ne!(full, short);
    }

    #[test]
    fn domain_separation_hash_vs_keyed() {
        let data = b"test data";
        let plain = Hasher::new().update(data).finalize();
        let keyed = Hasher::new_keyed(&[0u8; OUTPUT_BYTES]).update(data).finalize();
        assert_ne!(plain, keyed);
    }

    #[test]
    fn keyed_mode_differs_from_plain_hash_of_key_then_data() {
        let key = [7u8; OUTPUT_BYTES];
        let keyed = Hasher::new_keyed(&key).update(b"msg").finalize();
        let plain = Hasher::new().update(&key).update(b"msg").finalize();
        assert_ne!(keyed, plain);
    }

    #[test]
    fn domain_separation_hash_vs_derive_key() {
        let data = b"test material";
        let plain = Hasher::new().update(data).finalize();
        let ctx_hasher = Hasher::new_derive_key_context("test context");
        let ctx_hash = ctx_hasher.finalize();
        let derived = Hasher::new_derive_key_material(&ctx_hash)
            .update(data)
            .finalize();
        assert_ne!(plain, derived);
    }

    #[test]
    fn derive_key_depends_on_context() {
        let derive = |ctx: &str| {
            let ctx_hash = Hasher::new_derive_key_context(ctx).finalize();
            Hasher::new_derive_key_material(&ctx_hash)
                .update(b"material")
                .finalize()
        };
        assert_ne!(derive("context A"), derive("context B"));
    }

    #[test]
    fn xof_prefix_matches_finalize() {
        let data = b"xof test";
        let hash = Hasher::new().update(data).finalize();
        let mut xof = Hasher::new().update(data).finalize_xof();
        let mut xof_bytes = [0u8; OUTPUT_BYTES];
        xof.fill(&mut xof_bytes);
        assert_eq!(hash.as_bytes(), &xof_bytes);
    }

    #[test]
    fn xof_produces_more_than_one_output_block() {
        let mut xof = Hasher::new().update(b"xof").finalize_xof();
        let mut out = [0u8; OUTPUT_BYTES * 3];
        xof.fill(&mut out);
        // Not all zeros.
        assert_ne!(out, [0u8; OUTPUT_BYTES * 3]);
        // Consecutive squeezed blocks differ (with overwhelming probability).
        assert_ne!(out[..OUTPUT_BYTES], out[OUTPUT_BYTES..OUTPUT_BYTES * 2]);
        assert_ne!(out[OUTPUT_BYTES..OUTPUT_BYTES * 2], out[OUTPUT_BYTES * 2..]);
    }

    #[test]
    fn xof_output_is_independent_of_read_sizes() {
        let data = b"chunked reads";
        let mut whole = [0u8; OUTPUT_BYTES * 2 + 5];
        Hasher::new().update(data).finalize_xof().fill(&mut whole);

        let mut pieces = [0u8; OUTPUT_BYTES * 2 + 5];
        let mut xof = Hasher::new().update(data).finalize_xof();
        for chunk in pieces.chunks_mut(3) {
            xof.fill(chunk);
        }
        assert_eq!(whole, pieces);
    }

    #[test]
    fn xof_read_trait() {
        use std::io::Read;
        let mut xof = Hasher::new().update(b"read trait").finalize_xof();
        let mut buf = [0u8; 64];
        let n = xof.read(&mut buf).unwrap();
        assert_eq!(n, 64);
    }

    #[test]
    fn keyed_hash_different_keys() {
        let data = b"same data";
        let h1 = Hasher::new_keyed(&[0u8; OUTPUT_BYTES]).update(data).finalize();
        let h2 = Hasher::new_keyed(&[1u8; OUTPUT_BYTES]).update(data).finalize();
        assert_ne!(h1, h2);
    }
}