lib_q_poseidon/
permutation.rs1#[cfg(feature = "alloc")]
10extern crate alloc;
11
12#[cfg(feature = "alloc")]
13use alloc::vec::Vec;
14
15use lib_q_stark_field::PrimeCharacteristicRing;
16use lib_q_stark_field::extension::Complex;
17use lib_q_stark_mersenne31::Mersenne31;
18
19use crate::constants::sbox;
20use crate::params::PoseidonParams;
21
22type F = Complex<Mersenne31>;
24
25#[cfg(feature = "alloc")]
27pub type PoseidonState = Vec<F>;
28
29#[derive(Debug, Clone)]
34pub struct PoseidonPermutation {
35 params: PoseidonParams,
36}
37
38impl PoseidonPermutation {
39 pub fn new(params: PoseidonParams) -> Self {
41 let n = params.state_width;
42 assert!(
43 (2..=16).contains(&n),
44 "state_width must be in 2..=16, got {}",
45 n
46 );
47 let required = (params.full_rounds + params.partial_rounds) * n;
48 assert!(
49 params.round_constants.len() >= required,
50 "Insufficient round constants: need {}, have {}",
51 required,
52 params.round_constants.len()
53 );
54 assert_eq!(
55 params.mds_matrix.len(),
56 n,
57 "MDS matrix must have {} rows",
58 n
59 );
60 for (i, row) in params.mds_matrix.iter().enumerate() {
61 assert_eq!(row.len(), n, "MDS matrix row {} must have {} columns", i, n);
62 }
63 Self { params }
64 }
65
66 #[cfg(feature = "alloc")]
76 pub fn permute(&self, mut state: PoseidonState) -> PoseidonState {
77 let full_rounds_half = self.params.full_rounds / 2;
78 let mut round_const_idx = 0;
79
80 for _ in 0..full_rounds_half {
82 state = self.full_round(state, &mut round_const_idx);
83 }
84
85 for _ in 0..self.params.partial_rounds {
87 state = self.partial_round(state, &mut round_const_idx);
88 }
89
90 for _ in 0..full_rounds_half {
92 state = self.full_round(state, &mut round_const_idx);
93 }
94
95 state
96 }
97
98 #[cfg(feature = "alloc")]
100 fn full_round(&self, mut state: PoseidonState, round_const_idx: &mut usize) -> PoseidonState {
101 let n = self.params.state_width;
102 for (i, s) in state.iter_mut().enumerate().take(n) {
103 *s += self.params.round_constants[*round_const_idx + i];
104 }
105 *round_const_idx += n;
106 for s in state.iter_mut().take(n) {
107 *s = sbox(*s);
108 }
109 self.mix_layer(state)
110 }
111
112 #[cfg(feature = "alloc")]
114 fn partial_round(
115 &self,
116 mut state: PoseidonState,
117 round_const_idx: &mut usize,
118 ) -> PoseidonState {
119 let n = self.params.state_width;
120 for (i, s) in state.iter_mut().enumerate().take(n) {
121 *s += self.params.round_constants[*round_const_idx + i];
122 }
123 *round_const_idx += n;
124 state[0] = sbox(state[0]);
125 self.mix_layer(state)
126 }
127
128 #[cfg(feature = "alloc")]
130 fn mix_layer(&self, state: PoseidonState) -> PoseidonState {
131 let n = self.params.state_width;
132 let mds = &self.params.mds_matrix;
133 let mut new_state = alloc::vec![F::ZERO; n];
134 for i in 0..n {
135 for j in 0..n {
136 new_state[i] += mds[i][j] * state[j];
137 }
138 }
139 new_state
140 }
141
142 pub fn params(&self) -> &PoseidonParams {
144 &self.params
145 }
146}
147
// `permute` and `PoseidonState` only exist with the `alloc` feature, so the tests do too.
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::params::Poseidon128;

    /// Builds the state `[1, 2, ..., width]` (for width 5: the original fixed test vector).
    fn sample_state(width: usize) -> PoseidonState {
        core::iter::successors(Some(F::ONE), |&x| Some(x + F::ONE))
            .take(width)
            .collect()
    }

    #[test]
    fn test_permutation_is_not_identity() {
        // Sanity check only: a permutation that returned its input unchanged would be broken.
        let perm = Poseidon128::permutation();
        let state = sample_state(perm.params().state_width);
        let permuted = perm.permute(state.clone());
        assert_ne!(state, permuted);
    }

    #[test]
    fn test_permutation_deterministic() {
        let perm = Poseidon128::permutation();
        let state = sample_state(perm.params().state_width);
        let result1 = perm.permute(state.clone());
        let result2 = perm.permute(state);
        assert_eq!(result1, result2);
    }

    #[test]
    fn test_permutation_preserves_width() {
        let perm = Poseidon128::permutation();
        let width = perm.params().state_width;
        let permuted = perm.permute(vec![F::ONE; width]);
        assert_eq!(permuted.len(), width);
    }

    #[test]
    #[should_panic]
    fn test_permutation_rejects_short_state() {
        let perm = Poseidon128::permutation();
        let width = perm.params().state_width;
        let _ = perm.permute(vec![F::ONE; width - 1]);
    }
}