Skip to main content

lib_q_poseidon/
permutation.rs

//! Poseidon permutation implementation
//!
//! This module implements the core Poseidon permutation function.
//! Each round applies three steps in order:
//! 1. AddRoundConstants (ARC)
//! 2. SubWords (S-box)
//! 3. MixLayer (MDS matrix multiplication)
8
9#[cfg(feature = "alloc")]
10extern crate alloc;
11
12#[cfg(feature = "alloc")]
13use alloc::vec::Vec;
14
15use lib_q_stark_field::PrimeCharacteristicRing;
16use lib_q_stark_field::extension::Complex;
17use lib_q_stark_mersenne31::Mersenne31;
18
19use crate::constants::sbox;
20use crate::params::PoseidonParams;
21
22/// Field type for permutation
23type F = Complex<Mersenne31>;
24
25/// Poseidon permutation state (variable length: state_width elements)
26#[cfg(feature = "alloc")]
27pub type PoseidonState = Vec<F>;
28
29/// Poseidon permutation function
30///
31/// This implements the full Poseidon permutation with configurable
32/// round counts and state width for different security levels.
33#[derive(Debug, Clone)]
34pub struct PoseidonPermutation {
35    params: PoseidonParams,
36}
37
38impl PoseidonPermutation {
39    /// Create a new Poseidon permutation with the given parameters
40    pub fn new(params: PoseidonParams) -> Self {
41        let n = params.state_width;
42        assert!(
43            (2..=16).contains(&n),
44            "state_width must be in 2..=16, got {}",
45            n
46        );
47        let required = (params.full_rounds + params.partial_rounds) * n;
48        assert!(
49            params.round_constants.len() >= required,
50            "Insufficient round constants: need {}, have {}",
51            required,
52            params.round_constants.len()
53        );
54        assert_eq!(
55            params.mds_matrix.len(),
56            n,
57            "MDS matrix must have {} rows",
58            n
59        );
60        for (i, row) in params.mds_matrix.iter().enumerate() {
61            assert_eq!(row.len(), n, "MDS matrix row {} must have {} columns", i, n);
62        }
63        Self { params }
64    }
65
66    /// Apply the Poseidon permutation to the state
67    ///
68    /// # Arguments
69    ///
70    /// * `state` - The state to permute (state_width field elements)
71    ///
72    /// # Returns
73    ///
74    /// The permuted state
75    #[cfg(feature = "alloc")]
76    pub fn permute(&self, mut state: PoseidonState) -> PoseidonState {
77        let full_rounds_half = self.params.full_rounds / 2;
78        let mut round_const_idx = 0;
79
80        // First half of full rounds
81        for _ in 0..full_rounds_half {
82            state = self.full_round(state, &mut round_const_idx);
83        }
84
85        // Partial rounds
86        for _ in 0..self.params.partial_rounds {
87            state = self.partial_round(state, &mut round_const_idx);
88        }
89
90        // Second half of full rounds
91        for _ in 0..full_rounds_half {
92            state = self.full_round(state, &mut round_const_idx);
93        }
94
95        state
96    }
97
98    /// Apply a full round (S-box on all elements)
99    #[cfg(feature = "alloc")]
100    fn full_round(&self, mut state: PoseidonState, round_const_idx: &mut usize) -> PoseidonState {
101        let n = self.params.state_width;
102        for (i, s) in state.iter_mut().enumerate().take(n) {
103            *s += self.params.round_constants[*round_const_idx + i];
104        }
105        *round_const_idx += n;
106        for s in state.iter_mut().take(n) {
107            *s = sbox(*s);
108        }
109        self.mix_layer(state)
110    }
111
112    /// Apply a partial round (S-box only on first element)
113    #[cfg(feature = "alloc")]
114    fn partial_round(
115        &self,
116        mut state: PoseidonState,
117        round_const_idx: &mut usize,
118    ) -> PoseidonState {
119        let n = self.params.state_width;
120        for (i, s) in state.iter_mut().enumerate().take(n) {
121            *s += self.params.round_constants[*round_const_idx + i];
122        }
123        *round_const_idx += n;
124        state[0] = sbox(state[0]);
125        self.mix_layer(state)
126    }
127
128    /// Apply the MDS matrix multiplication (linear layer)
129    #[cfg(feature = "alloc")]
130    fn mix_layer(&self, state: PoseidonState) -> PoseidonState {
131        let n = self.params.state_width;
132        let mds = &self.params.mds_matrix;
133        let mut new_state = alloc::vec![F::ZERO; n];
134        for i in 0..n {
135            for j in 0..n {
136                new_state[i] += mds[i][j] * state[j];
137            }
138        }
139        new_state
140    }
141
142    /// Get a reference to the parameters
143    pub fn params(&self) -> &PoseidonParams {
144        &self.params
145    }
146}
147
// `PoseidonState` and `permute` only exist with the `alloc` feature, so the
// tests are gated on it as well.
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::params::Poseidon128;

    /// Fixed five-element input shared by the tests below.
    fn sample_state() -> PoseidonState {
        vec![
            F::ONE,
            F::from(Mersenne31::new(2)),
            F::from(Mersenne31::new(3)),
            F::from(Mersenne31::new(4)),
            F::from(Mersenne31::new(5)),
        ]
    }

    /// The permutation must not leave the sample input unchanged.
    #[test]
    fn test_permutation_changes_state() {
        let perm = Poseidon128::permutation();
        let state = sample_state();
        let permuted = perm.permute(state.clone());
        assert_ne!(state, permuted);
    }

    /// Permuting the same input twice must give the same output.
    #[test]
    fn test_permutation_deterministic() {
        let perm = Poseidon128::permutation();
        let state = sample_state();
        let result1 = perm.permute(state.clone());
        let result2 = perm.permute(state);
        assert_eq!(result1, result2);
    }
}