1use p3_field::{AbstractField, PrimeField};
2use p3_mds::MdsPermutation;
3use p3_symmetric::Permutation;
4use serde::{Deserialize, Serialize};
5
6extern crate alloc;
7
8pub trait MdsLightPermutation<T: Clone, const WIDTH: usize>: Permutation<[T; WIDTH]> {}
11
12fn apply_hl_mat4<AF>(x: &mut [AF; 4])
20where
21 AF: AbstractField,
22{
23 let t0 = x[0].clone() + x[1].clone();
24 let t1 = x[2].clone() + x[3].clone();
25 let t2 = x[1].clone() + x[1].clone() + t1.clone();
26 let t3 = x[3].clone() + x[3].clone() + t0.clone();
27 let t4 = t1.double().double() + t3.clone();
28 let t5 = t0.double().double() + t2.clone();
29 let t6 = t3 + t5.clone();
30 let t7 = t2 + t4.clone();
31 x[0] = t6;
32 x[1] = t5;
33 x[2] = t7;
34 x[3] = t4;
35}
36
37fn apply_mat4<AF>(x: &mut [AF; 4])
44where
45 AF: AbstractField,
46{
47 let t01 = x[0].clone() + x[1].clone();
48 let t23 = x[2].clone() + x[3].clone();
49 let t0123 = t01.clone() + t23.clone();
50 let t01123 = t0123.clone() + x[1].clone();
51 let t01233 = t0123.clone() + x[3].clone();
52 x[3] = t01233.clone() + x[0].double(); x[1] = t01123.clone() + x[2].double(); x[0] = t01123 + t01; x[2] = t01233 + t23; }
58
59#[derive(Clone, Default)]
61pub struct HLMDSMat4;
62
63impl<AF: AbstractField> Permutation<[AF; 4]> for HLMDSMat4 {
64 fn permute(&self, input: [AF; 4]) -> [AF; 4] {
65 let mut output = input.clone();
66 self.permute_mut(&mut output);
67 output
68 }
69
70 fn permute_mut(&self, input: &mut [AF; 4]) {
71 apply_hl_mat4(input)
72 }
73}
74impl<AF: AbstractField> MdsPermutation<AF, 4> for HLMDSMat4 {}
75
76#[derive(Clone, Default)]
77pub struct MDSMat4;
78
79impl<AF: AbstractField> Permutation<[AF; 4]> for MDSMat4 {
80 fn permute(&self, input: [AF; 4]) -> [AF; 4] {
81 let mut output = input.clone();
82 self.permute_mut(&mut output);
83 output
84 }
85
86 fn permute_mut(&self, input: &mut [AF; 4]) {
87 apply_mat4(input)
88 }
89}
90impl<AF: AbstractField> MdsPermutation<AF, 4> for MDSMat4 {}
91
92fn mds_light_permutation<AF: AbstractField, MdsPerm4: MdsPermutation<AF, 4>, const WIDTH: usize>(
93 state: &mut [AF; WIDTH],
94 mdsmat: MdsPerm4,
95) {
96 match WIDTH {
97 2 => {
98 let sum = state[0].clone() + state[1].clone();
99 state[0] += sum.clone();
100 state[1] += sum;
101 }
102
103 3 => {
104 let sum = state[0].clone() + state[1].clone() + state[2].clone();
105 state[0] += sum.clone();
106 state[1] += sum.clone();
107 state[2] += sum;
108 }
109
110 4 | 8 | 12 | 16 | 20 | 24 => {
111 for i in (0..WIDTH).step_by(4) {
114 let mut state_4 = [
116 state[i].clone(),
117 state[i + 1].clone(),
118 state[i + 2].clone(),
119 state[i + 3].clone(),
120 ];
121 mdsmat.permute_mut(&mut state_4);
122 state[i..i + 4].clone_from_slice(&state_4);
123 }
124 let sums: [AF; 4] = core::array::from_fn(|k| {
128 (0..WIDTH)
129 .step_by(4)
130 .map(|j| state[j + k].clone())
131 .sum::<AF>()
132 });
133
134 for i in 0..WIDTH {
137 state[i] += sums[i % 4].clone();
138 }
139 }
140
141 _ => {
142 panic!("Unsupported width");
143 }
144 }
145}
146
147#[derive(Default, Clone, Serialize, Deserialize)]
148pub struct Poseidon2ExternalMatrixGeneral;
149
150impl<AF, const WIDTH: usize> Permutation<[AF; WIDTH]> for Poseidon2ExternalMatrixGeneral
151where
152 AF: AbstractField,
153 AF::F: PrimeField,
154{
155 fn permute_mut(&self, state: &mut [AF; WIDTH]) {
156 mds_light_permutation::<AF, MDSMat4, WIDTH>(state, MDSMat4)
157 }
158}
159
160impl<AF, const WIDTH: usize> MdsLightPermutation<AF, WIDTH> for Poseidon2ExternalMatrixGeneral
161where
162 AF: AbstractField,
163 AF::F: PrimeField,
164{
165}
166
167#[derive(Default, Clone)]
168pub struct Poseidon2ExternalMatrixHL;
169
170impl<AF, const WIDTH: usize> Permutation<[AF; WIDTH]> for Poseidon2ExternalMatrixHL
171where
172 AF: AbstractField,
173 AF::F: PrimeField,
174{
175 fn permute_mut(&self, state: &mut [AF; WIDTH]) {
176 mds_light_permutation::<AF, HLMDSMat4, WIDTH>(state, HLMDSMat4)
177 }
178}
179
180impl<AF, const WIDTH: usize> MdsLightPermutation<AF, WIDTH> for Poseidon2ExternalMatrixHL
181where
182 AF: AbstractField,
183 AF::F: PrimeField,
184{
185}