1use std::{
2 arch::x86_64::{
3 __m256i, __m512i, _mm256_and_si256, _mm256_or_si256, _mm256_ternarylogic_epi32,
4 _mm256_test_epi64_mask, _mm256_testz_si256, _mm256_xor_si256, _mm512_and_si512,
5 _mm512_or_si512, _mm512_set1_epi64, _mm512_setzero_si512, _mm512_ternarylogic_epi32,
6 _mm512_test_epi64_mask, _mm512_xor_si512,
7 },
8 io::{Error, ErrorKind},
9 sync::OnceLock,
10};
11
12use wide::u64x8;
13
14use crate::transpose::transpose_scalar;
15
16pub mod benchmark;
17pub mod transpose;
18
19pub const ALL_ONES: u64x8 = u64x8::splat(0xFFFFFFFFFFFFFFFF);
20pub const ZERO: u64x8 = u64x8::ZERO;
21
22pub fn splat(n: u64) -> u64x8 {
23 u64x8::splat(n)
24}
25
26pub fn bitsliced_add(a: &[u64x8; 64], b: &[u64x8; 64]) -> [u64x8; 64] {
29 let mut carry = u64x8::ZERO;
30 let mut sum = [u64x8::ZERO; 64];
31 for i in (0..64).rev() {
32 let res = calc_sum_carry(a[i], b[i], carry);
33 sum[i] = res.0;
34 carry = res.1;
36 }
37 sum
38}
39
40pub fn bitsliced_add_single(a: &[u64x8; 64], b: u64) -> [u64x8; 64] {
41 let mut carry = u64x8::ZERO;
42 let mut sum = [u64x8::ZERO; 64];
43 for i in (0..64).rev() {
44 let shift_right = 63 - i;
45 let current_bit = (b >> shift_right) & 1;
46 let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
47 let res = calc_sum_carry(a[i], b_i, carry);
48 sum[i] = res.0;
49 carry = res.1;
51 }
52 sum
53}
54
55pub fn bitsliced_add_inline(a: &mut [u64x8; 64], b: &[u64x8; 64]) {
56 let mut carry = u64x8::ZERO;
57 for i in (0..64).rev() {
58 let res = calc_sum_carry(a[i], b[i], carry);
59 a[i] = res.0;
60 carry = res.1;
62 }
63}
64
65pub fn bitsliced_add_single_inline(a: &mut [u64x8; 64], b: u64) {
66 let mut carry = u64x8::ZERO;
67 for i in (0..64).rev() {
68 let shift_right = 63 - i;
69 let current_bit = (b >> shift_right) & 1;
70 let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
71 let res = calc_sum_carry(a[i], b_i, carry);
72 a[i] = res.0;
73 carry = res.1;
75 }
76}
77
78fn calc_sum_carry(a: u64x8, b: u64x8, carry: u64x8) -> (u64x8, u64x8) {
79 let sum = a ^ b ^ carry;
80 let next_carry = (a & b) | (carry & (a ^ b));
81 (sum, next_carry)
82}
83
/// 512-bit vector with every bit set (all eight 64-bit lanes are `u64::MAX`).
// SAFETY: `__m512i` is 64 bytes of plain integer data, the same size as `[u64; 8]`.
const M512_ONES: __m512i = unsafe { std::mem::transmute([u64::MAX; 8]) };
/// 512-bit vector with every bit cleared.
// SAFETY: same size/plain-data argument as for `M512_ONES`.
const M512_ZERO: __m512i = unsafe { std::mem::transmute([0u64; 8]) };
86
87#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
88pub unsafe fn bitsliced_add_single_inline_avx_512(a: &mut [__m512i; 64], b: u64) {
89 let mut carry = M512_ZERO;
90 let max_bit_pos = 64 - (b.leading_zeros() as usize);
91 for i in (0..64).rev() {
92 let bit_index = 63 - i;
93 if bit_index >= max_bit_pos {
95 if _mm512_test_epi64_mask(carry, carry) == 0 {
96 break;
97 }
98 }
99
100 let current_bit = if ((b >> bit_index) & 1) == 1 {
101 M512_ONES
102 } else {
103 M512_ZERO
104 };
105
106 let a_orig = a[i];
107
108 a[i] = _mm512_ternarylogic_epi32(a_orig, current_bit, carry, 0x96);
109
110 carry = _mm512_ternarylogic_epi32(a_orig, current_bit, carry, 0xE8);
111 }
112}
113
/// 256-bit vector with every bit set (all four 64-bit lanes are `u64::MAX`).
// SAFETY: `__m256i` is 32 bytes of plain integer data, the same size as `[u64; 4]`.
const M2_ONES: __m256i = unsafe { std::mem::transmute([u64::MAX; 4]) };
/// 256-bit vector with every bit cleared.
// SAFETY: same size/plain-data argument as for `M2_ONES`.
const M2_ZERO: __m256i = unsafe { std::mem::transmute([0u64; 4]) };
116
117#[target_feature(enable = "avx2")]
118pub unsafe fn bitsliced_add_single_inline_avx_2(a: &mut [__m256i; 64], b: u64) {
119 let mut carry = M2_ZERO;
120 let max_bit_pos = 64 - (b.leading_zeros() as usize);
121 for i in (0..64).rev() {
122 let bit_index = 63 - i;
123 if bit_index >= max_bit_pos {
125 if _mm256_testz_si256(carry, carry) != 0 {
126 break;
127 }
128 }
129
130 let current_bit = if ((b >> bit_index) & 1) == 1 {
131 M2_ONES
132 } else {
133 M2_ZERO
134 };
135
136 let a_orig = a[i];
137
138 let xor_ab = _mm256_xor_si256(a_orig, current_bit);
139 let and_ab = _mm256_and_si256(a_orig, current_bit);
140 a[i] = _mm256_xor_si256(xor_ab, carry);
141
142 carry = _mm256_or_si256(and_ab, _mm256_and_si256(carry, xor_ab));
143 }
144}
145
146pub fn bitsliced_modulo_power_of_two(a: &[u64x8; 64], k: usize) -> Result<[u64x8; 64], Error> {
150 if k > 64 {
151 return Err(Error::new(
152 ErrorKind::InvalidData,
153 "k must be <= 64 for bitsliced modulo",
154 ));
155 }
156 let mut out = [u64x8::splat(0); 64];
157 let start: usize = 64 - k;
158 out[start..].copy_from_slice(&a[start..]);
159
160 Ok(out)
161}
162
163pub fn bitsliced_modulo_power_of_two_inline(a: &mut [u64x8; 64], k: usize) -> Result<(), Error> {
164 if k > 64 {
165 return Err(Error::new(
166 ErrorKind::InvalidData,
167 "k must be <= 64 for bitsliced modulo",
168 ));
169 }
170 let end: usize = 64 - k;
171 for i in 0..end {
172 a[i] = u64x8::splat(0);
173 }
174
175 Ok(())
176}
177
178#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
179pub fn bitsliced_modulo_power_of_two_inline_avx_512(
180 a: &mut [__m512i; 64],
181 k: usize,
182) -> Result<(), Error> {
183 if k > 64 {
184 return Err(Error::new(
185 ErrorKind::InvalidData,
186 "k must be <= 64 for bitsliced modulo",
187 ));
188 }
189 let end: usize = 64 - k;
190 for i in 0..end {
191 a[i] = M512_ZERO
192 }
193
194 Ok(())
195}
196#[target_feature(enable = "avx2")]
197pub fn bitsliced_modulo_power_of_two_inline_avx_2(
198 a: &mut [__m256i; 64],
199 k: usize,
200) -> Result<(), Error> {
201 if k > 64 {
202 return Err(Error::new(
203 ErrorKind::InvalidData,
204 "k must be <= 64 for bitsliced modulo",
205 ));
206 }
207 let end: usize = 64 - k;
208 for i in 0..end {
209 a[i] = M2_ZERO
210 }
211
212 Ok(())
213}
214
215pub fn des_reduction(h: &[u64x8; 64], i: u64) -> [u64x8; 64] {
218 let mut sum = bitsliced_add_single(h, i);
219 bitsliced_modulo_power_of_two_inline(&mut sum, 56).unwrap();
220 sum
221}
222
223pub fn des_reduction_inline(h: &mut [u64x8; 64], i: u64) {
224 bitsliced_add_single_inline(h, i);
225 bitsliced_modulo_power_of_two_inline(h, 56).unwrap();
226}
227
228#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
229pub unsafe fn des_reduction_inline_avx_512(h: &mut [__m512i; 64], i: u64) {
230 unsafe { bitsliced_add_single_inline_avx_512(h, i) };
231 bitsliced_modulo_power_of_two_inline_avx_512(h, 56).unwrap();
232}
233
234#[target_feature(enable = "avx2")]
235pub unsafe fn des_reduction_inline_avx_2(h: &mut [__m256i; 64], i: u64) {
236 unsafe { bitsliced_add_single_inline_avx_2(h, i) };
237 bitsliced_modulo_power_of_two_inline_avx_2(h, 56).unwrap();
238}
239
240static USE_GFNI: OnceLock<bool> = OnceLock::new();
241
242pub fn transpose_64x64(input: &[u64; 64]) -> [u64; 64] {
245 if *USE_GFNI.get_or_init(|| {
246 #[cfg(target_arch = "x86_64")]
247 {
248 std::is_x86_feature_detected!("gfni")
249 && std::is_x86_feature_detected!("avx512f")
250 && std::is_x86_feature_detected!("avx512bw")
251 && std::is_x86_feature_detected!("avx512vbmi")
252 }
253 #[cfg(not(target_arch = "x86_64"))]
254 {
255 false
256 }
257 }) {
258 unsafe { crate::transpose::transpose_gfni(input) }
259 } else {
260 transpose_scalar(input)
261 }
262}
263
#[cfg(test)]
mod tests {
    use std::arch::x86_64::{__m256i, __m512i};

    use super::*;

    /// Reinterprets a 512-bit vector as its eight 64-bit lanes.
    fn lanes_512(v: __m512i) -> [u64; 8] {
        // SAFETY: both types are 64 bytes of plain integer data; every bit pattern is valid.
        unsafe { std::mem::transmute(v) }
    }

    /// Reinterprets a 256-bit vector as its four 64-bit lanes.
    fn lanes_256(v: __m256i) -> [u64; 4] {
        // SAFETY: both types are 32 bytes of plain integer data; every bit pattern is valid.
        unsafe { std::mem::transmute(v) }
    }

    /// Whether the CPU supports every feature the AVX-512 kernels are compiled with.
    /// The AVX tests return early when this is false, instead of dying with an
    /// illegal-instruction fault.
    fn has_avx_512() -> bool {
        std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512vl")
            && std::is_x86_feature_detected!("avx512bw")
    }

    /// Whether the CPU supports AVX2.
    fn has_avx_2() -> bool {
        std::is_x86_feature_detected!("avx2")
    }

    #[test]
    fn test_add_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let mut b = [ZERO; 64];
        b[63] = ALL_ONES;
        let sum = bitsliced_add(&a, &b);
        assert_eq!(sum[63], ZERO);
        assert_eq!(sum[62], ALL_ONES);
        for (i, slice) in sum.iter().enumerate().take(62) {
            assert_eq!(*slice, ZERO, "index {i}");
        }
    }

    #[test]
    fn test_add_single_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let sum = bitsliced_add_single(&a, 1);
        assert_eq!(sum[63], ZERO);
        assert_eq!(sum[62], ALL_ONES);
        for (i, slice) in sum.iter().enumerate().take(62) {
            assert_eq!(*slice, ZERO, "index {i}");
        }
    }

    #[test]
    fn test_add_single_wraps_on_overflow() {
        // Every packed value is u64::MAX; adding 1 must wrap to 0.
        let a = [ALL_ONES; 64];
        let sum = bitsliced_add_single(&a, 1);
        assert!(sum.iter().all(|s| *s == ZERO));
    }

    #[test]
    fn test_add_inline_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let mut b = [ZERO; 64];
        b[63] = ALL_ONES;
        bitsliced_add_inline(&mut a, &b);
        assert_eq!(a[63], ZERO);
        assert_eq!(a[62], ALL_ONES);
        for (i, slice) in a.iter().enumerate().take(62) {
            assert_eq!(*slice, ZERO, "index {i}");
        }
    }

    #[test]
    fn test_add_single_inline_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        bitsliced_add_single_inline(&mut a, 1);
        assert_eq!(a[63], ZERO);
        assert_eq!(a[62], ALL_ONES);
        for (i, slice) in a.iter().enumerate().take(62) {
            assert_eq!(*slice, ZERO, "index {i}");
        }
    }

    #[test]
    fn test_add_single_inline_avx_512_works() {
        if !has_avx_512() {
            eprintln!("skipping: CPU lacks avx512f/avx512vl/avx512bw");
            return;
        }
        let mut a = [M512_ZERO; 64];
        a[63] = M512_ONES;
        // SAFETY: the required AVX-512 features were detected above.
        unsafe { bitsliced_add_single_inline_avx_512(&mut a, 1) };
        assert_eq!(lanes_512(a[63]), [0; 8]);
        assert_eq!(lanes_512(a[62]), [u64::MAX; 8]);
        for (i, slice) in a.iter().enumerate().take(62) {
            assert_eq!(lanes_512(*slice), [0; 8], "index {i}");
        }
    }

    #[test]
    fn test_add_single_inline_avx_2_works() {
        if !has_avx_2() {
            eprintln!("skipping: CPU lacks avx2");
            return;
        }
        let mut a = [M2_ZERO; 64];
        a[63] = M2_ONES;
        // SAFETY: AVX2 was detected above.
        unsafe { bitsliced_add_single_inline_avx_2(&mut a, 1) };
        assert_eq!(lanes_256(a[63]), [0; 4]);
        assert_eq!(lanes_256(a[62]), [u64::MAX; 4]);
        for (i, slice) in a.iter().enumerate().take(62) {
            assert_eq!(lanes_256(*slice), [0; 4], "index {i}");
        }
    }

    #[test]
    fn test_modulo_works() {
        let a = [ALL_ONES; 64];
        let res = bitsliced_modulo_power_of_two(&a, 56).unwrap();
        for (i, slice) in res.iter().enumerate() {
            let expected = if i < 8 { ZERO } else { ALL_ONES };
            assert_eq!(*slice, expected, "index {i}");
        }
    }

    #[test]
    fn test_modulo_inline_works() {
        let mut a = [ALL_ONES; 64];
        bitsliced_modulo_power_of_two_inline(&mut a, 56).unwrap();
        for (i, slice) in a.iter().enumerate() {
            let expected = if i < 8 { ZERO } else { ALL_ONES };
            assert_eq!(*slice, expected, "index {i}");
        }
    }

    #[test]
    fn test_modulo_rejects_k_above_64() {
        let mut a = [ALL_ONES; 64];
        let err = bitsliced_modulo_power_of_two_inline(&mut a, 65).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        // The operand must be untouched on error.
        assert!(a.iter().all(|s| *s == ALL_ONES));
        assert!(bitsliced_modulo_power_of_two(&a, 65).is_err());
    }

    #[test]
    fn test_des_reduction_wraps_at_56_bits() {
        // Every packed value is 2^56 - 1 (slices 8..64 hold bits 55..=0).
        let mut h = [ZERO; 64];
        for slice in h.iter_mut().skip(8) {
            *slice = ALL_ONES;
        }
        let reduced = des_reduction(&h, 1);
        assert!(reduced.iter().all(|s| *s == ZERO));
        des_reduction_inline(&mut h, 1);
        assert!(h.iter().all(|s| *s == ZERO));
    }

    #[test]
    fn test_modulo_inline_avx_512_works() {
        if !has_avx_512() {
            eprintln!("skipping: CPU lacks avx512f/avx512vl/avx512bw");
            return;
        }
        let mut a = [M512_ONES; 64];
        // SAFETY: the required AVX-512 features were detected above.
        unsafe { bitsliced_modulo_power_of_two_inline_avx_512(&mut a, 56) }.unwrap();
        for (i, slice) in a.iter().enumerate() {
            let expected = if i < 8 { [0; 8] } else { [u64::MAX; 8] };
            assert_eq!(lanes_512(*slice), expected, "index {i}");
        }
    }

    #[test]
    fn test_modulo_inline_avx_2_works() {
        if !has_avx_2() {
            eprintln!("skipping: CPU lacks avx2");
            return;
        }
        let mut a = [M2_ONES; 64];
        // SAFETY: AVX2 was detected above.
        unsafe { bitsliced_modulo_power_of_two_inline_avx_2(&mut a, 56) }.unwrap();
        for (i, slice) in a.iter().enumerate() {
            let expected = if i < 8 { [0; 4] } else { [u64::MAX; 4] };
            assert_eq!(lanes_256(*slice), expected, "index {i}");
        }
    }
}
397}