Skip to main content

p3_poseidon2/
matrix.rs

1use p3_field::{AbstractField, PrimeField};
2use p3_mds::MdsPermutation;
3use p3_symmetric::Permutation;
4use serde::{Deserialize, Serialize};
5
6extern crate alloc;
7
8/// For the external layers we use a matrix of the form circ(2M_4, M_4, ..., M_4)
9/// Where M_4 is a 4 x 4 MDS matrix. This leads to a permutation which has slightly weaker properties to MDS
10pub trait MdsLightPermutation<T: Clone, const WIDTH: usize>: Permutation<[T; WIDTH]> {}
11
12// Multiply a 4-element vector x by
13// [ 5 7 1 3 ]
14// [ 4 6 1 1 ]
15// [ 1 3 5 7 ]
16// [ 1 1 4 6 ].
17// This uses the formula from the start of Appendix B in the Poseidon2 paper, with multiplications unrolled into additions.
18// It is also the matrix used by the Horizon Labs implementation.
19fn apply_hl_mat4<AF>(x: &mut [AF; 4])
20where
21    AF: AbstractField,
22{
23    let t0 = x[0].clone() + x[1].clone();
24    let t1 = x[2].clone() + x[3].clone();
25    let t2 = x[1].clone() + x[1].clone() + t1.clone();
26    let t3 = x[3].clone() + x[3].clone() + t0.clone();
27    let t4 = t1.double().double() + t3.clone();
28    let t5 = t0.double().double() + t2.clone();
29    let t6 = t3 + t5.clone();
30    let t7 = t2 + t4.clone();
31    x[0] = t6;
32    x[1] = t5;
33    x[2] = t7;
34    x[3] = t4;
35}
36
37// Multiply a 4-element vector x by:
38// [ 2 3 1 1 ]
39// [ 1 2 3 1 ]
40// [ 1 1 2 3 ]
41// [ 3 1 1 2 ].
42// This is more efficient than the previous matrix.
43fn apply_mat4<AF>(x: &mut [AF; 4])
44where
45    AF: AbstractField,
46{
47    let t01 = x[0].clone() + x[1].clone();
48    let t23 = x[2].clone() + x[3].clone();
49    let t0123 = t01.clone() + t23.clone();
50    let t01123 = t0123.clone() + x[1].clone();
51    let t01233 = t0123.clone() + x[3].clone();
52    // The order here is important. Need to overwrite x[0] and x[2] after x[1] and x[3].
53    x[3] = t01233.clone() + x[0].double(); // 3*x[0] + x[1] + x[2] + 2*x[3]
54    x[1] = t01123.clone() + x[2].double(); // x[0] + 2*x[1] + 3*x[2] + x[3]
55    x[0] = t01123 + t01; // 2*x[0] + 3*x[1] + x[2] + x[3]
56    x[2] = t01233 + t23; // x[0] + x[1] + 2*x[2] + 3*x[3]
57}
58
59// The 4x4 MDS matrix used by the Horizon Labs implementation of Poseidon2.
60#[derive(Clone, Default)]
61pub struct HLMDSMat4;
62
63impl<AF: AbstractField> Permutation<[AF; 4]> for HLMDSMat4 {
64    fn permute(&self, input: [AF; 4]) -> [AF; 4] {
65        let mut output = input.clone();
66        self.permute_mut(&mut output);
67        output
68    }
69
70    fn permute_mut(&self, input: &mut [AF; 4]) {
71        apply_hl_mat4(input)
72    }
73}
74impl<AF: AbstractField> MdsPermutation<AF, 4> for HLMDSMat4 {}
75
76#[derive(Clone, Default)]
77pub struct MDSMat4;
78
79impl<AF: AbstractField> Permutation<[AF; 4]> for MDSMat4 {
80    fn permute(&self, input: [AF; 4]) -> [AF; 4] {
81        let mut output = input.clone();
82        self.permute_mut(&mut output);
83        output
84    }
85
86    fn permute_mut(&self, input: &mut [AF; 4]) {
87        apply_mat4(input)
88    }
89}
90impl<AF: AbstractField> MdsPermutation<AF, 4> for MDSMat4 {}
91
92fn mds_light_permutation<AF: AbstractField, MdsPerm4: MdsPermutation<AF, 4>, const WIDTH: usize>(
93    state: &mut [AF; WIDTH],
94    mdsmat: MdsPerm4,
95) {
96    match WIDTH {
97        2 => {
98            let sum = state[0].clone() + state[1].clone();
99            state[0] += sum.clone();
100            state[1] += sum;
101        }
102
103        3 => {
104            let sum = state[0].clone() + state[1].clone() + state[2].clone();
105            state[0] += sum.clone();
106            state[1] += sum.clone();
107            state[2] += sum;
108        }
109
110        4 | 8 | 12 | 16 | 20 | 24 => {
111            // First, we apply M_4 to each consecutive four elements of the state.
112            // In Appendix B's terminology, this replaces each x_i with x_i'.
113            for i in (0..WIDTH).step_by(4) {
114                // Would be nice to find a better way to do this.
115                let mut state_4 = [
116                    state[i].clone(),
117                    state[i + 1].clone(),
118                    state[i + 2].clone(),
119                    state[i + 3].clone(),
120                ];
121                mdsmat.permute_mut(&mut state_4);
122                state[i..i + 4].clone_from_slice(&state_4);
123            }
124            // Now, we apply the outer circulant matrix (to compute the y_i values).
125
126            // We first precompute the four sums of every four elements.
127            let sums: [AF; 4] = core::array::from_fn(|k| {
128                (0..WIDTH)
129                    .step_by(4)
130                    .map(|j| state[j + k].clone())
131                    .sum::<AF>()
132            });
133
134            // The formula for each y_i involves 2x_i' term and x_j' terms for each j that equals i mod 4.
135            // In other words, we can add a single copy of x_i' to the appropriate one of our precomputed sums
136            for i in 0..WIDTH {
137                state[i] += sums[i % 4].clone();
138            }
139        }
140
141        _ => {
142            panic!("Unsupported width");
143        }
144    }
145}
146
147#[derive(Default, Clone, Serialize, Deserialize)]
148pub struct Poseidon2ExternalMatrixGeneral;
149
150impl<AF, const WIDTH: usize> Permutation<[AF; WIDTH]> for Poseidon2ExternalMatrixGeneral
151where
152    AF: AbstractField,
153    AF::F: PrimeField,
154{
155    fn permute_mut(&self, state: &mut [AF; WIDTH]) {
156        mds_light_permutation::<AF, MDSMat4, WIDTH>(state, MDSMat4)
157    }
158}
159
160impl<AF, const WIDTH: usize> MdsLightPermutation<AF, WIDTH> for Poseidon2ExternalMatrixGeneral
161where
162    AF: AbstractField,
163    AF::F: PrimeField,
164{
165}
166
167#[derive(Default, Clone)]
168pub struct Poseidon2ExternalMatrixHL;
169
170impl<AF, const WIDTH: usize> Permutation<[AF; WIDTH]> for Poseidon2ExternalMatrixHL
171where
172    AF: AbstractField,
173    AF::F: PrimeField,
174{
175    fn permute_mut(&self, state: &mut [AF; WIDTH]) {
176        mds_light_permutation::<AF, HLMDSMat4, WIDTH>(state, HLMDSMat4)
177    }
178}
179
180impl<AF, const WIDTH: usize> MdsLightPermutation<AF, WIDTH> for Poseidon2ExternalMatrixHL
181where
182    AF: AbstractField,
183    AF::F: PrimeField,
184{
185}