// Source: trident/package/poseidon2.rs

//! Poseidon2 hash function over the Goldilocks field (p = 2^64 - 2^32 + 1).
//!
//! Implements the Poseidon2 permutation (Grassi et al., 2023) with:
//!   - State width t = 8, rate = 4, capacity = 4
//!   - S-box x^7
//!   - 8 full rounds (4 + 4) and 22 partial rounds
//!   - Round constants derived deterministically from BLAKE3

9/// Goldilocks prime: p = 2^64 - 2^32 + 1
10const P: u64 = crate::field::goldilocks::MODULUS;
11
12/// Poseidon2 state width.
13const T: usize = 8;
14/// Rate (number of input elements absorbed per permutation call).
15const RATE: usize = 4;
16/// Number of full rounds.
17const R_F: usize = 8;
18/// Number of partial rounds.
19const R_P: usize = 22;
20/// S-box exponent: gcd(7, p-1) = 1 for the Goldilocks prime.
21#[cfg(test)]
22const ALPHA: u64 = 7;
23
24/// Internal diagonal constants: d_i = 1 + 2^i.
25const DIAG: [u64; T] = [2, 3, 5, 9, 17, 33, 65, 129];

// ---------------------------------------------------------------------------
// Goldilocks field element
// ---------------------------------------------------------------------------

31/// A field element in the Goldilocks field (u64 modulo `P`).
32#[derive(Clone, Copy, Debug, PartialEq, Eq)]
33pub struct GoldilocksField(pub u64);
34
35impl GoldilocksField {
36    pub const ZERO: Self = Self(0);
37    pub const ONE: Self = Self(1);
38
39    /// Canonical constructor -- reduces `v` modulo `P`.
40    #[inline]
41    pub fn new(v: u64) -> Self {
42        Self(v % P)
43    }
44
45    /// Reduce a u128 value modulo P using 2^64 = 2^32 - 1 (mod P).
46    #[inline]
47    fn reduce128(x: u128) -> Self {
48        // Goldilocks: P = 2^64 - 2^32 + 1, so 2^64 ≡ 2^32 - 1 (mod P).
49        // Split x = lo + hi * 2^64, then x ≡ lo + hi * (2^32 - 1) (mod P).
50        let lo = x as u64;
51        let hi = (x >> 64) as u64;
52        let hi_shifted = (hi as u128) * ((1u128 << 32) - 1);
53        let sum = lo as u128 + hi_shifted;
54        // sum fits in ~97 bits max. Split again.
55        let lo2 = sum as u64;
56        let hi2 = (sum >> 64) as u64;
57        if hi2 == 0 {
58            Self(if lo2 >= P { lo2 - P } else { lo2 })
59        } else {
60            // hi2 is at most ~2^32, so r fits in ~65 bits.
61            let r = lo2 as u128 + (hi2 as u128) * ((1u128 << 32) - 1);
62            // r may exceed 2^64, so one more split may be needed.
63            let lo3 = r as u64;
64            let hi3 = (r >> 64) as u64;
65            if hi3 == 0 {
66                Self(if lo3 >= P { lo3 - P } else { lo3 })
67            } else {
68                // hi3 is at most 1, so final value fits in u64.
69                let v = lo3.wrapping_add(hi3.wrapping_mul(u32::MAX as u64));
70                Self(if v >= P { v - P } else { v })
71            }
72        }
73    }
74
75    #[inline]
76    pub fn add(self, rhs: Self) -> Self {
77        let (sum, carry) = self.0.overflowing_add(rhs.0);
78        if carry {
79            let r = sum + (u32::MAX as u64);
80            Self(if r >= P { r - P } else { r })
81        } else {
82            Self(if sum >= P { sum - P } else { sum })
83        }
84    }
85
86    #[inline]
87    pub fn sub(self, rhs: Self) -> Self {
88        if self.0 >= rhs.0 {
89            Self(self.0 - rhs.0)
90        } else {
91            Self(P - rhs.0 + self.0)
92        }
93    }
94
95    #[inline]
96    pub fn mul(self, rhs: Self) -> Self {
97        Self::reduce128((self.0 as u128) * (rhs.0 as u128))
98    }
99
100    /// Exponentiation via square-and-multiply.
101    pub fn pow(self, mut exp: u64) -> Self {
102        let mut base = self;
103        let mut acc = Self::ONE;
104        while exp > 0 {
105            if exp & 1 == 1 {
106                acc = acc.mul(base);
107            }
108            base = base.mul(base);
109            exp >>= 1;
110        }
111        acc
112    }
113
114    /// The Poseidon2 S-box: x^7.
115    #[inline]
116    pub fn sbox(self) -> Self {
117        let x2 = self.mul(self);
118        let x3 = x2.mul(self);
119        let x6 = x3.mul(x3);
120        x6.mul(self)
121    }
122}

// ---------------------------------------------------------------------------
// Round-constant generation
// ---------------------------------------------------------------------------

128const TOTAL_ROUNDS: usize = R_F + R_P;
129
130/// Generate the round constant for (`round`, `element`) deterministically.
131fn round_constant(round: usize, element: usize) -> GoldilocksField {
132    let tag = format!("Poseidon2-Goldilocks-t8-RF8-RP22-{round}-{element}");
133    let digest = blake3::hash(tag.as_bytes());
134    let bytes: [u8; 8] = digest.as_bytes()[..8].try_into().unwrap_or([0u8; 8]);
135    GoldilocksField::new(u64::from_le_bytes(bytes))
136}
137
138/// Generate all round constants: T per full round, 1 per partial round.
139fn generate_all_constants() -> Vec<GoldilocksField> {
140    let mut constants = Vec::new();
141    for r in 0..TOTAL_ROUNDS {
142        let is_full = r < R_F / 2 || r >= R_F / 2 + R_P;
143        if is_full {
144            for e in 0..T {
145                constants.push(round_constant(r, e));
146            }
147        } else {
148            constants.push(round_constant(r, 0));
149        }
150    }
151    constants
152}
153
154/// Cached round constants, computed once on first access.
155fn cached_round_constants() -> &'static [GoldilocksField] {
156    static CONSTANTS: std::sync::OnceLock<Vec<GoldilocksField>> = std::sync::OnceLock::new();
157    CONSTANTS.get_or_init(generate_all_constants)
158}

// ---------------------------------------------------------------------------
// Poseidon2 state & permutation
// ---------------------------------------------------------------------------

164/// The Poseidon2 internal state (8 Goldilocks elements).
165pub struct Poseidon2Sponge {
166    pub state: [GoldilocksField; T],
167}
168
169impl Poseidon2Sponge {
170    pub fn new() -> Self {
171        Self {
172            state: [GoldilocksField::ZERO; T],
173        }
174    }
175
176    /// Apply the S-box to every element (full round).
177    #[inline]
178    fn full_sbox(&mut self) {
179        for s in self.state.iter_mut() {
180            *s = s.sbox();
181        }
182    }
183
184    /// Apply the S-box to element 0 only (partial round).
185    #[inline]
186    fn partial_sbox(&mut self) {
187        self.state[0] = self.state[0].sbox();
188    }
189
190    /// External linear layer: circ(2,1,1,...,1).
191    /// new[i] = 2*state[i] + sum(state).
192    fn external_linear(&mut self) {
193        let sum = self
194            .state
195            .iter()
196            .fold(GoldilocksField::ZERO, |a, &b| a.add(b));
197        for s in self.state.iter_mut() {
198            *s = s.add(sum); // state[i] + sum(all) = 2*state[i] + sum(others)
199        }
200    }
201
202    /// Internal linear layer: diag(d_0,...,d_7) + ones_matrix.
203    /// new[i] = d_i * state[i] + sum(state).
204    fn internal_linear(&mut self) {
205        let sum = self
206            .state
207            .iter()
208            .fold(GoldilocksField::ZERO, |a, &b| a.add(b));
209        for (i, s) in self.state.iter_mut().enumerate() {
210            *s = GoldilocksField(DIAG[i]).mul(*s).add(sum);
211        }
212    }
213
214    /// Full Poseidon2 permutation (in-place).
215    pub fn permutation(&mut self) {
216        let constants = cached_round_constants();
217        let mut ci = 0;
218
219        // First R_F/2 full rounds
220        for _ in 0..R_F / 2 {
221            for s in self.state.iter_mut() {
222                *s = s.add(constants[ci]);
223                ci += 1;
224            }
225            self.full_sbox();
226            self.external_linear();
227        }
228
229        // R_P partial rounds
230        for _ in 0..R_P {
231            self.state[0] = self.state[0].add(constants[ci]);
232            ci += 1;
233            self.partial_sbox();
234            self.internal_linear();
235        }
236
237        // Last R_F/2 full rounds
238        for _ in 0..R_F / 2 {
239            for s in self.state.iter_mut() {
240                *s = s.add(constants[ci]);
241                ci += 1;
242            }
243            self.full_sbox();
244            self.external_linear();
245        }
246
247        debug_assert_eq!(ci, constants.len());
248    }
249}

// ---------------------------------------------------------------------------
// Sponge-based hasher
// ---------------------------------------------------------------------------

255/// Poseidon2 sponge hasher (absorb / squeeze interface).
256pub struct Poseidon2Hasher {
257    state: Poseidon2Sponge,
258    absorbed: usize,
259}
260
261impl Poseidon2Hasher {
262    pub fn new() -> Self {
263        Self {
264            state: Poseidon2Sponge::new(),
265            absorbed: 0,
266        }
267    }
268
269    /// Absorb field elements into the sponge (rate portion of the state).
270    pub fn absorb(&mut self, elements: &[GoldilocksField]) {
271        for &elem in elements {
272            if self.absorbed == RATE {
273                self.state.permutation();
274                self.absorbed = 0;
275            }
276            self.state.state[self.absorbed] = self.state.state[self.absorbed].add(elem);
277            self.absorbed += 1;
278        }
279    }
280
281    /// Absorb raw bytes (7 bytes per element to stay below P).
282    pub fn absorb_bytes(&mut self, data: &[u8]) {
283        const BYTES_PER_ELEM: usize = 7;
284        let mut elements = Vec::with_capacity(data.len() / BYTES_PER_ELEM + 2);
285        for chunk in data.chunks(BYTES_PER_ELEM) {
286            let mut buf = [0u8; 8];
287            buf[..chunk.len()].copy_from_slice(chunk);
288            elements.push(GoldilocksField::new(u64::from_le_bytes(buf)));
289        }
290        // Length separator so [] and [0x00] hash differently.
291        elements.push(GoldilocksField::new(data.len() as u64));
292        self.absorb(&elements);
293    }
294
295    /// Squeeze `count` field elements out of the sponge.
296    pub fn squeeze(&mut self, count: usize) -> Vec<GoldilocksField> {
297        let mut out = Vec::with_capacity(count);
298        self.state.permutation();
299        self.absorbed = 0;
300        let mut squeezed = 0;
301        loop {
302            for &elem in self.state.state[..RATE].iter() {
303                out.push(elem);
304                squeezed += 1;
305                if squeezed == count {
306                    return out;
307                }
308            }
309            self.state.permutation();
310        }
311    }
312
313    /// Finalize and return a single field-element hash.
314    pub fn finalize(mut self) -> GoldilocksField {
315        self.squeeze(1)[0]
316    }
317
318    /// Finalize and return 4 field elements (256-bit equivalent).
319    pub fn finalize_4(mut self) -> [GoldilocksField; 4] {
320        let v = self.squeeze(4);
321        [v[0], v[1], v[2], v[3]]
322    }
323}

// ---------------------------------------------------------------------------
// Convenience helpers
// ---------------------------------------------------------------------------

329/// Hash arbitrary bytes to a 256-bit content hash (32 bytes).
330pub fn hash_bytes(data: &[u8]) -> [u8; 32] {
331    let mut hasher = Poseidon2Hasher::new();
332    hasher.absorb_bytes(data);
333    let result = hasher.finalize_4();
334    let mut out = [0u8; 32];
335    for (i, elem) in result.iter().enumerate() {
336        out[i * 8..i * 8 + 8].copy_from_slice(&elem.0.to_le_bytes());
337    }
338    out
339}
340
341/// Hash a slice of field elements directly, returning 4 field elements.
342pub fn hash_fields(elements: &[GoldilocksField]) -> [GoldilocksField; 4] {
343    let mut hasher = Poseidon2Hasher::new();
344    hasher.absorb(elements);
345    hasher.finalize_4()
346}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Basic field laws at the edges of the canonical range.
    #[test]
    fn test_goldilocks_arithmetic() {
        let a = GoldilocksField::new(P - 1);
        let b = GoldilocksField::ONE;
        // (p-1) + 1 = 0
        assert_eq!(a.add(b), GoldilocksField::ZERO);
        // 0 - 1 = p-1
        assert_eq!(GoldilocksField::ZERO.sub(b), a);
        // Multiplication identity and zero
        let x = GoldilocksField::new(123456789);
        assert_eq!(x.mul(GoldilocksField::ONE), x);
        assert_eq!(x.mul(GoldilocksField::ZERO), GoldilocksField::ZERO);
        // Commutativity
        let y = GoldilocksField::new(987654321);
        assert_eq!(x.mul(y), y.mul(x));
        // Pow: x^0 = 1, x^1 = x, x^3 = x*x*x
        assert_eq!(x.pow(0), GoldilocksField::ONE);
        assert_eq!(x.pow(1), x);
        assert_eq!(x.pow(3), x.mul(x).mul(x));
        // (-1)^2 = 1
        assert_eq!(a.mul(a), GoldilocksField::ONE);
    }

    /// The hard-coded x^7 chain must agree with generic exponentiation.
    #[test]
    fn test_sbox() {
        let x = GoldilocksField::new(42);
        assert_eq!(x.sbox(), x.pow(ALPHA));
        assert_eq!(GoldilocksField::ZERO.sbox(), GoldilocksField::ZERO);
        assert_eq!(GoldilocksField::ONE.sbox(), GoldilocksField::ONE);
        let z = GoldilocksField::new(1000);
        assert_ne!(z.sbox(), z);
        assert_eq!(z.sbox(), z.pow(7));
    }

    /// Two permutations of the same input give the same output.
    #[test]
    fn test_permutation_deterministic() {
        let input: [GoldilocksField; T] =
            core::array::from_fn(|i| GoldilocksField::new(i as u64 + 1));
        let mut s1 = Poseidon2Sponge { state: input };
        let mut s2 = Poseidon2Sponge { state: input };
        s1.permutation();
        s2.permutation();
        assert_eq!(s1.state, s2.state);
    }

    /// Changing one input element changes every output element.
    #[test]
    fn test_permutation_diffusion() {
        let base: [GoldilocksField; T] =
            core::array::from_fn(|i| GoldilocksField::new(i as u64 + 100));
        let mut s_base = Poseidon2Sponge { state: base };
        s_base.permutation();

        let mut tweaked = base;
        tweaked[0] = tweaked[0].add(GoldilocksField::ONE);
        let mut s_tweak = Poseidon2Sponge { state: tweaked };
        s_tweak.permutation();

        for (i, (before, after)) in s_base.state.iter().zip(s_tweak.state.iter()).enumerate() {
            assert_ne!(before, after, "Element {i} unchanged after input tweak");
        }
    }

    #[test]
    fn test_hash_bytes_deterministic() {
        assert_eq!(hash_bytes(b"hello world"), hash_bytes(b"hello world"));
    }

    #[test]
    fn test_hash_bytes_different_inputs() {
        assert_ne!(hash_bytes(b"hello"), hash_bytes(b"world"));
    }

    /// Absorbing more than one rate block and squeezing is reproducible.
    #[test]
    fn test_absorb_squeeze() {
        let elems: Vec<GoldilocksField> =
            (0..10).map(|i| GoldilocksField::new(i * 7 + 3)).collect();

        let mut h1 = Poseidon2Hasher::new();
        h1.absorb(&elems);
        let out1 = h1.squeeze(4);

        let mut h2 = Poseidon2Hasher::new();
        h2.absorb(&elems);
        let out2 = h2.squeeze(4);

        assert_eq!(out1, out2);
        assert!(out1.iter().any(|e| *e != GoldilocksField::ZERO));
    }

    #[test]
    fn test_hash_fields() {
        let elems: Vec<GoldilocksField> = (1..=5).map(GoldilocksField::new).collect();
        assert_eq!(hash_fields(&elems), hash_fields(&elems));
    }

    /// The empty input still produces a stable, non-zero digest.
    #[test]
    fn test_empty_hash() {
        let h = hash_bytes(b"");
        assert_eq!(h, hash_bytes(b""));
        assert_ne!(h, [0u8; 32]);
    }

    /// Smoke test: 20 distinct small inputs give pairwise distinct digests.
    #[test]
    fn test_collision_resistance() {
        let hashes: Vec<[u8; 32]> = (0u64..20).map(|i| hash_bytes(&i.to_le_bytes())).collect();
        for (i, first) in hashes.iter().enumerate() {
            for (j, second) in hashes.iter().enumerate().skip(i + 1) {
                assert_ne!(first, second, "Collision between inputs {i} and {j}");
            }
        }
    }
}