1use std::{
2 arch::x86_64::{
3 __m512i, _mm512_and_si512, _mm512_or_si512, _mm512_set1_epi64, _mm512_setzero_si512,
4 _mm512_ternarylogic_epi32, _mm512_test_epi64_mask, _mm512_xor_si512,
5 },
6 io::{Error, ErrorKind},
7 sync::OnceLock,
8};
9
10use wide::u64x8;
11
12use crate::transpose::transpose_scalar;
13
14pub mod benchmark;
15pub mod transpose;
16
17pub const ALL_ONES: u64x8 = u64x8::splat(0xFFFFFFFFFFFFFFFF);
18pub const ZERO: u64x8 = u64x8::ZERO;
19
20pub fn splat(n: u64) -> u64x8 {
21 u64x8::splat(n)
22}
23
24pub fn bitsliced_add(a: &[u64x8; 64], b: &[u64x8; 64]) -> [u64x8; 64] {
27 let mut carry = u64x8::ZERO;
28 let mut sum = [u64x8::ZERO; 64];
29 for i in (0..64).rev() {
30 let res = calc_sum_carry(a[i], b[i], carry);
31 sum[i] = res.0;
32 carry = res.1;
34 }
35 sum
36}
37
38pub fn bitsliced_add_single(a: &[u64x8; 64], b: u64) -> [u64x8; 64] {
39 let mut carry = u64x8::ZERO;
40 let mut sum = [u64x8::ZERO; 64];
41 for i in (0..64).rev() {
42 let shift_right = 63 - i;
43 let current_bit = (b >> shift_right) & 1;
44 let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
45 let res = calc_sum_carry(a[i], b_i, carry);
46 sum[i] = res.0;
47 carry = res.1;
49 }
50 sum
51}
52
53pub fn bitsliced_add_inline(a: &mut [u64x8; 64], b: &[u64x8; 64]) {
54 let mut carry = u64x8::ZERO;
55 for i in (0..64).rev() {
56 let res = calc_sum_carry(a[i], b[i], carry);
57 a[i] = res.0;
58 carry = res.1;
60 }
61}
62
63pub fn bitsliced_add_single_inline(a: &mut [u64x8; 64], b: u64) {
64 let mut carry = u64x8::ZERO;
65 for i in (0..64).rev() {
66 let shift_right = 63 - i;
67 let current_bit = (b >> shift_right) & 1;
68 let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
69 let res = calc_sum_carry(a[i], b_i, carry);
70 a[i] = res.0;
71 carry = res.1;
73 }
74}
75
76fn calc_sum_carry(a: u64x8, b: u64x8, carry: u64x8) -> (u64x8, u64x8) {
77 let sum = a ^ b ^ carry;
78 let next_carry = (a & b) | (carry & (a ^ b));
79 (sum, next_carry)
80}
81
/// An AVX-512 value with all 512 bits set (every 64-bit lane is `u64::MAX`).
const M512_ONES: __m512i =
    // SAFETY: `[u64; 8]` and `__m512i` are both 64 bytes, and every bit pattern is a valid `__m512i`.
    unsafe { std::mem::transmute::<[u64; 8], __m512i>([u64::MAX; 8]) };

/// An AVX-512 value with all 512 bits cleared.
const M512_ZERO: __m512i =
    // SAFETY: `[u64; 8]` and `__m512i` are both 64 bytes, and every bit pattern is a valid `__m512i`.
    unsafe { std::mem::transmute::<[u64; 8], __m512i>([0; 8]) };
85
86pub unsafe fn bitsliced_add_single_inline_avx(a: &mut [__m512i; 64], b: u64) {
87 let mut carry = M512_ZERO;
88 let max_bit_pos = 64 - (b.leading_zeros() as usize);
89 for i in (0..64).rev() {
90 let bit_index = 63 - i;
91 if bit_index >= max_bit_pos {
93 if _mm512_test_epi64_mask(carry, carry) == 0 {
94 break;
95 }
96 }
97
98 let current_bit = if ((b >> bit_index) & 1) == 1 {
99 M512_ONES
100 } else {
101 M512_ZERO
102 };
103
104 let a_orig = a[i];
105
106 a[i] = _mm512_ternarylogic_epi32(a_orig, current_bit, carry, 0x96);
107
108 carry = _mm512_ternarylogic_epi32(a_orig, current_bit, carry, 0xE8);
109 }
110}
111
112pub fn bitsliced_modulo_power_of_two(a: &[u64x8; 64], k: usize) -> Result<[u64x8; 64], Error> {
116 if k > 64 {
117 return Err(Error::new(
118 ErrorKind::InvalidData,
119 "k must be <= 64 for bitsliced modulo",
120 ));
121 }
122 let mut out = [u64x8::splat(0); 64];
123 let start: usize = 64 - k;
124 out[start..].copy_from_slice(&a[start..]);
125
126 Ok(out)
127}
128
129pub fn bitsliced_modulo_power_of_two_inline(a: &mut [u64x8; 64], k: usize) -> Result<(), Error> {
130 if k > 64 {
131 return Err(Error::new(
132 ErrorKind::InvalidData,
133 "k must be <= 64 for bitsliced modulo",
134 ));
135 }
136 let end: usize = 64 - k;
137 for i in 0..end {
138 a[i] = u64x8::splat(0);
139 }
140
141 Ok(())
142}
143
144pub fn bitsliced_modulo_power_of_two_inline_avx(
145 a: &mut [__m512i; 64],
146 k: usize,
147) -> Result<(), Error> {
148 if k > 64 {
149 return Err(Error::new(
150 ErrorKind::InvalidData,
151 "k must be <= 64 for bitsliced modulo",
152 ));
153 }
154 let end: usize = 64 - k;
155 for i in 0..end {
156 a[i] = M512_ZERO
157 }
158
159 Ok(())
160}
161
162pub fn des_reduction(h: &[u64x8; 64], i: u64) -> [u64x8; 64] {
165 let mut sum = bitsliced_add_single(h, i);
166 bitsliced_modulo_power_of_two_inline(&mut sum, 56).unwrap();
167 sum
168}
169
170pub fn des_reduction_inline(h: &mut [u64x8; 64], i: u64) {
171 bitsliced_add_single_inline(h, i);
172 bitsliced_modulo_power_of_two_inline(h, 56).unwrap();
173}
174
175pub unsafe fn des_reduction_inline_avx(h: &mut [__m512i; 64], i: u64) {
176 unsafe { bitsliced_add_single_inline_avx(h, i) };
177 bitsliced_modulo_power_of_two_inline_avx(h, 56).unwrap();
178}
179
180static USE_GFNI: OnceLock<bool> = OnceLock::new();
181
182pub fn transpose_64x64(input: &[u64; 64]) -> [u64; 64] {
185 if *USE_GFNI.get_or_init(|| {
186 #[cfg(target_arch = "x86_64")]
187 {
188 std::is_x86_feature_detected!("gfni")
189 && std::is_x86_feature_detected!("avx512f")
190 && std::is_x86_feature_detected!("avx512bw")
191 && std::is_x86_feature_detected!("avx512vbmi")
192 }
193 #[cfg(not(target_arch = "x86_64"))]
194 {
195 false
196 }
197 }) {
198 unsafe { crate::transpose::transpose_gfni(input) }
199 } else {
200 transpose_scalar(input)
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Addends covering the no-op, a single low bit, full-width carries, the 56-bit mask
    /// and an arbitrary pattern.
    const ADDENDS: [u64; 5] = [0, 1, u64::MAX, 0x00FF_FFFF_FFFF_FFFF, 0x1234_5678_9ABC_DEF0];

    /// Deterministic pseudo-random inputs (xorshift64).
    ///
    /// `0` and `u64::MAX` are planted at the front so full-length carry chains are exercised.
    /// `seed` must be non-zero.
    fn test_values(seed: u64) -> [u64; 64] {
        let mut state = seed;
        let mut values = [0u64; 64];
        for v in values.iter_mut() {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            *v = state;
        }
        values[0] = 0;
        values[1] = u64::MAX;
        values
    }

    /// Bitslices 64 scalars into 64 row words.
    ///
    /// Bit `j` of row `i` is bit `63 - i` of `values[j]`, so row 63 holds the least
    /// significant bits.
    fn slice_values(values: &[u64; 64]) -> [u64; 64] {
        std::array::from_fn(|i| {
            values
                .iter()
                .enumerate()
                .fold(0u64, |row, (j, &v)| row | (((v >> (63 - i)) & 1) << j))
        })
    }

    /// Whether the AVX-512 code paths can run on this host. AVX tests skip instead of
    /// faulting with SIGILL.
    fn has_avx512f() -> bool {
        std::is_x86_feature_detected!("avx512f")
    }

    /// Broadcasts `x` into all eight 64-bit lanes of an AVX-512 value (no AVX instructions).
    fn broadcast(x: u64) -> __m512i {
        // SAFETY: both types are 64 bytes and every bit pattern is valid for each.
        unsafe { std::mem::transmute([x; 8]) }
    }

    /// Returns the eight 64-bit lanes of an AVX-512 value (no AVX instructions).
    fn lanes(v: __m512i) -> [u64; 8] {
        // SAFETY: both types are 64 bytes and every bit pattern is a valid `u64`.
        unsafe { std::mem::transmute(v) }
    }

    #[test]
    fn test_add_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let mut b = [ZERO; 64];
        b[63] = ALL_ONES;
        let sum = bitsliced_add(&a, &b);
        assert_eq!(sum[63], ZERO);
        assert_eq!(sum[62], ALL_ONES);
        for row in &sum[..62] {
            assert_eq!(*row, ZERO);
        }
    }

    #[test]
    fn test_add_single_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let sum = bitsliced_add_single(&a, 1);
        assert_eq!(sum[63], ZERO);
        assert_eq!(sum[62], ALL_ONES);
        for row in &sum[..62] {
            assert_eq!(*row, ZERO);
        }
    }

    #[test]
    fn test_add_inline_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let mut b = [ZERO; 64];
        b[63] = ALL_ONES;
        bitsliced_add_inline(&mut a, &b);
        assert_eq!(a[63], ZERO);
        assert_eq!(a[62], ALL_ONES);
        for row in &a[..62] {
            assert_eq!(*row, ZERO);
        }
    }

    #[test]
    fn test_add_single_inline_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        bitsliced_add_single_inline(&mut a, 1);
        assert_eq!(a[63], ZERO);
        assert_eq!(a[62], ALL_ONES);
        for row in &a[..62] {
            assert_eq!(*row, ZERO);
        }
    }

    #[test]
    fn test_scalar_adds_match_wrapping_add() {
        let x = test_values(0x9E37_79B9_7F4A_7C15);
        let y = test_values(0xD1B5_4A32_D192_ED03);
        let xs = slice_values(&x).map(splat);
        let ys = slice_values(&y).map(splat);

        let pairwise: [u64; 64] = std::array::from_fn(|j| x[j].wrapping_add(y[j]));
        let expected = slice_values(&pairwise).map(splat);
        assert_eq!(bitsliced_add(&xs, &ys), expected);
        let mut acc = xs;
        bitsliced_add_inline(&mut acc, &ys);
        assert_eq!(acc, expected);

        for b in ADDENDS {
            let expected = slice_values(&x.map(|v| v.wrapping_add(b))).map(splat);
            assert_eq!(bitsliced_add_single(&xs, b), expected, "b = {b:#x}");
            let mut acc = xs;
            bitsliced_add_single_inline(&mut acc, b);
            assert_eq!(acc, expected, "b = {b:#x}");
        }
    }

    #[test]
    fn test_add_single_inline_avx_works() {
        if !has_avx512f() {
            eprintln!("skipping: CPU lacks avx512f");
            return;
        }
        // SAFETY: avx512f availability was checked above.
        let mut a = [unsafe { _mm512_setzero_si512() }; 64];
        a[63] = broadcast(u64::MAX);
        // SAFETY: avx512f availability was checked above.
        unsafe { bitsliced_add_single_inline_avx(&mut a, 1) };
        assert_eq!(lanes(a[63]), [0; 8]);
        assert_eq!(lanes(a[62]), [u64::MAX; 8]);
        for row in &a[..62] {
            assert_eq!(lanes(*row), [0; 8]);
        }
    }

    #[test]
    fn test_add_single_inline_avx_matches_wrapping_add() {
        if !has_avx512f() {
            eprintln!("skipping: CPU lacks avx512f");
            return;
        }
        let values = test_values(0x2545_F491_4F6C_DD1D);
        for b in ADDENDS {
            let mut a = slice_values(&values).map(broadcast);
            // SAFETY: avx512f availability was checked above.
            unsafe { bitsliced_add_single_inline_avx(&mut a, b) };
            let expected = slice_values(&values.map(|v| v.wrapping_add(b)));
            for (row, want) in a.iter().zip(expected) {
                assert_eq!(lanes(*row), [want; 8], "b = {b:#x}");
            }
        }
    }

    #[test]
    fn test_des_reduction_keeps_low_56_bits() {
        let x = test_values(0x0123_4567_89AB_CDEF);
        let mask = (1u64 << 56) - 1;
        let xs = slice_values(&x).map(splat);
        for b in ADDENDS {
            let expected = slice_values(&x.map(|v| v.wrapping_add(b) & mask)).map(splat);
            assert_eq!(des_reduction(&xs, b), expected, "b = {b:#x}");
            let mut h = xs;
            des_reduction_inline(&mut h, b);
            assert_eq!(h, expected, "b = {b:#x}");
        }
    }

    #[test]
    fn test_des_reduction_inline_avx_keeps_low_56_bits() {
        if !has_avx512f() {
            eprintln!("skipping: CPU lacks avx512f");
            return;
        }
        let x = test_values(0x0123_4567_89AB_CDEF);
        let mask = (1u64 << 56) - 1;
        for b in ADDENDS {
            let mut h = slice_values(&x).map(broadcast);
            // SAFETY: avx512f availability was checked above.
            unsafe { des_reduction_inline_avx(&mut h, b) };
            let expected = slice_values(&x.map(|v| v.wrapping_add(b) & mask));
            for (row, want) in h.iter().zip(expected) {
                assert_eq!(lanes(*row), [want; 8], "b = {b:#x}");
            }
        }
    }

    #[test]
    fn test_modulo_works() {
        let a = [ALL_ONES; 64];
        let res = bitsliced_modulo_power_of_two(&a, 56).unwrap();
        for row in &res[..8] {
            assert_eq!(*row, ZERO);
        }
        for row in &res[8..] {
            assert_eq!(*row, ALL_ONES);
        }
    }

    #[test]
    fn test_modulo_inline_works() {
        let mut a = [ALL_ONES; 64];
        bitsliced_modulo_power_of_two_inline(&mut a, 56).unwrap();
        for row in &a[..8] {
            assert_eq!(*row, ZERO);
        }
        for row in &a[8..] {
            assert_eq!(*row, ALL_ONES);
        }
    }

    #[test]
    fn test_modulo_edge_cases() {
        let a = [ALL_ONES; 64];
        assert_eq!(bitsliced_modulo_power_of_two(&a, 64).unwrap(), a);
        assert_eq!(bitsliced_modulo_power_of_two(&a, 0).unwrap(), [ZERO; 64]);
        let err = bitsliced_modulo_power_of_two(&a, 65).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);

        let mut b = a;
        assert!(bitsliced_modulo_power_of_two_inline(&mut b, 65).is_err());
        assert_eq!(b, a, "a rejected k must leave the input untouched");

        // The AVX modulo only moves values, so it needs no runtime feature check.
        let mut c = [broadcast(u64::MAX); 64];
        bitsliced_modulo_power_of_two_inline_avx(&mut c, 56).unwrap();
        for (i, row) in c.iter().enumerate() {
            let want = if i < 8 { 0 } else { u64::MAX };
            assert_eq!(lanes(*row), [want; 8]);
        }
        assert!(bitsliced_modulo_power_of_two_inline_avx(&mut c, 65).is_err());
    }
}