Skip to main content

bitsliced_op/
lib.rs

1use std::{
2    arch::x86_64::{
3        __m256i, __m512i, _mm256_and_si256, _mm256_or_si256, _mm256_ternarylogic_epi32,
4        _mm256_test_epi64_mask, _mm256_testz_si256, _mm256_xor_si256, _mm512_and_si512,
5        _mm512_or_si512, _mm512_set1_epi64, _mm512_setzero_si512, _mm512_ternarylogic_epi32,
6        _mm512_test_epi64_mask, _mm512_xor_si512,
7    },
8    io::{Error, ErrorKind},
9    sync::OnceLock,
10};
11
12use wide::u64x8;
13
14use crate::transpose::transpose_scalar;
15
16pub mod benchmark;
17pub mod transpose;
18
19pub const ALL_ONES: u64x8 = u64x8::splat(0xFFFFFFFFFFFFFFFF);
20pub const ZERO: u64x8 = u64x8::ZERO;
21
22pub fn splat(n: u64) -> u64x8 {
23    u64x8::splat(n)
24}
25
26//expects the input to be in bitsliced form e.g integers are columns, not rows
27//last row is LSB
28pub fn bitsliced_add(a: &[u64x8; 64], b: &[u64x8; 64]) -> [u64x8; 64] {
29    let mut carry = u64x8::ZERO;
30    let mut sum = [u64x8::ZERO; 64];
31    for i in (0..64).rev() {
32        let res = calc_sum_carry(a[i], b[i], carry);
33        sum[i] = res.0;
34        //only set carry if we haven't reached the end yet, we currently ignore overflows
35        carry = res.1;
36    }
37    sum
38}
39
40pub fn bitsliced_add_single(a: &[u64x8; 64], b: u64) -> [u64x8; 64] {
41    let mut carry = u64x8::ZERO;
42    let mut sum = [u64x8::ZERO; 64];
43    for i in (0..64).rev() {
44        let shift_right = 63 - i;
45        let current_bit = (b >> shift_right) & 1;
46        let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
47        let res = calc_sum_carry(a[i], b_i, carry);
48        sum[i] = res.0;
49        //only set carry if we haven't reached the end yet, we currently ignore overflows
50        carry = res.1;
51    }
52    sum
53}
54
55pub fn bitsliced_add_inline(a: &mut [u64x8; 64], b: &[u64x8; 64]) {
56    let mut carry = u64x8::ZERO;
57    for i in (0..64).rev() {
58        let res = calc_sum_carry(a[i], b[i], carry);
59        a[i] = res.0;
60        //only set carry if we haven't reached the end yet, we currently ignore overflows
61        carry = res.1;
62    }
63}
64
65pub fn bitsliced_add_single_inline(a: &mut [u64x8; 64], b: u64) {
66    let mut carry = u64x8::ZERO;
67    for i in (0..64).rev() {
68        let shift_right = 63 - i;
69        let current_bit = (b >> shift_right) & 1;
70        let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
71        let res = calc_sum_carry(a[i], b_i, carry);
72        a[i] = res.0;
73        //only set carry if we haven't reached the end yet, we currently ignore overflows
74        carry = res.1;
75    }
76}
77
78fn calc_sum_carry(a: u64x8, b: u64x8, carry: u64x8) -> (u64x8, u64x8) {
79    let sum = a ^ b ^ carry;
80    let next_carry = (a & b) | (carry & (a ^ b));
81    (sum, next_carry)
82}
83
// SAFETY: `[u64; 8]` and `__m512i` are both 64 bytes of plain data; any bit pattern is valid.
/// A 512-bit register with every bit set.
const M512_ONES: __m512i = unsafe { std::mem::transmute::<[u64; 8], __m512i>([u64::MAX; 8]) };
/// A 512-bit register with every bit cleared.
const M512_ZERO: __m512i = unsafe { std::mem::transmute::<[u64; 8], __m512i>([0; 8]) };
86
87#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
88pub unsafe fn bitsliced_add_single_inline_avx_512(a: &mut [__m512i; 64], b: u64) {
89    let mut carry = M512_ZERO;
90    let max_bit_pos = 64 - (b.leading_zeros() as usize);
91    for i in (0..64).rev() {
92        let bit_index = 63 - i;
93        //break early to save cpu cycles
94        if bit_index >= max_bit_pos {
95            if _mm512_test_epi64_mask(carry, carry) == 0 {
96                break;
97            }
98        }
99
100        let current_bit = if ((b >> bit_index) & 1) == 1 {
101            M512_ONES
102        } else {
103            M512_ZERO
104        };
105
106        let a_orig = a[i];
107
108        a[i] = _mm512_ternarylogic_epi32(a_orig, current_bit, carry, 0x96);
109
110        carry = _mm512_ternarylogic_epi32(a_orig, current_bit, carry, 0xE8);
111    }
112}
113
// SAFETY: `[u64; 4]` and `__m256i` are both 32 bytes of plain data; any bit pattern is valid.
/// A 256-bit register with every bit set.
const M2_ONES: __m256i = unsafe { std::mem::transmute::<[u64; 4], __m256i>([u64::MAX; 4]) };
/// A 256-bit register with every bit cleared.
const M2_ZERO: __m256i = unsafe { std::mem::transmute::<[u64; 4], __m256i>([0; 4]) };
116
117#[target_feature(enable = "avx2")]
118pub unsafe fn bitsliced_add_single_inline_avx_2(a: &mut [__m256i; 64], b: u64) {
119    let mut carry = M2_ZERO;
120    let max_bit_pos = 64 - (b.leading_zeros() as usize);
121    for i in (0..64).rev() {
122        let bit_index = 63 - i;
123        //break early to save cpu cycles
124        if bit_index >= max_bit_pos {
125            if _mm256_testz_si256(carry, carry) != 0 {
126                break;
127            }
128        }
129
130        let current_bit = if ((b >> bit_index) & 1) == 1 {
131            M2_ONES
132        } else {
133            M2_ZERO
134        };
135
136        let a_orig = a[i];
137
138        let xor_ab = _mm256_xor_si256(a_orig, current_bit);
139        let and_ab = _mm256_and_si256(a_orig, current_bit);
140        a[i] = _mm256_xor_si256(xor_ab, carry);
141
142        carry = _mm256_or_si256(and_ab, _mm256_and_si256(carry, xor_ab));
143    }
144}
145
146//this function only works when calculating the module with a number of the power of two
147//currently only supports a single modulo operation for all integers
148//example: if you want to calculate the modulo with 2^56, pass 56 to k
149pub fn bitsliced_modulo_power_of_two(a: &[u64x8; 64], k: usize) -> Result<[u64x8; 64], Error> {
150    if k > 64 {
151        return Err(Error::new(
152            ErrorKind::InvalidData,
153            "k must be <= 64 for bitsliced modulo",
154        ));
155    }
156    let mut out = [u64x8::splat(0); 64];
157    let start: usize = 64 - k;
158    out[start..].copy_from_slice(&a[start..]);
159
160    Ok(out)
161}
162
163pub fn bitsliced_modulo_power_of_two_inline(a: &mut [u64x8; 64], k: usize) -> Result<(), Error> {
164    if k > 64 {
165        return Err(Error::new(
166            ErrorKind::InvalidData,
167            "k must be <= 64 for bitsliced modulo",
168        ));
169    }
170    let end: usize = 64 - k;
171    for i in 0..end {
172        a[i] = u64x8::splat(0);
173    }
174
175    Ok(())
176}
177
178#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
179pub fn bitsliced_modulo_power_of_two_inline_avx_512(
180    a: &mut [__m512i; 64],
181    k: usize,
182) -> Result<(), Error> {
183    if k > 64 {
184        return Err(Error::new(
185            ErrorKind::InvalidData,
186            "k must be <= 64 for bitsliced modulo",
187        ));
188    }
189    let end: usize = 64 - k;
190    for i in 0..end {
191        a[i] = M512_ZERO
192    }
193
194    Ok(())
195}
196#[target_feature(enable = "avx2")]
197pub fn bitsliced_modulo_power_of_two_inline_avx_2(
198    a: &mut [__m256i; 64],
199    k: usize,
200) -> Result<(), Error> {
201    if k > 64 {
202        return Err(Error::new(
203            ErrorKind::InvalidData,
204            "k must be <= 64 for bitsliced modulo",
205        ));
206    }
207    let end: usize = 64 - k;
208    for i in 0..end {
209        a[i] = M2_ZERO
210    }
211
212    Ok(())
213}
214
215//reduction function: (H+I)%MAX_SIZE
216//H=Hash,I=Index in chain,MAX_SIZE=Max size of output in power of 2
217pub fn des_reduction(h: &[u64x8; 64], i: u64) -> [u64x8; 64] {
218    let mut sum = bitsliced_add_single(h, i);
219    bitsliced_modulo_power_of_two_inline(&mut sum, 56).unwrap();
220    sum
221}
222
223pub fn des_reduction_inline(h: &mut [u64x8; 64], i: u64) {
224    bitsliced_add_single_inline(h, i);
225    bitsliced_modulo_power_of_two_inline(h, 56).unwrap();
226}
227
228#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
229pub unsafe fn des_reduction_inline_avx_512(h: &mut [__m512i; 64], i: u64) {
230    unsafe { bitsliced_add_single_inline_avx_512(h, i) };
231    bitsliced_modulo_power_of_two_inline_avx_512(h, 56).unwrap();
232}
233
234#[target_feature(enable = "avx2")]
235pub unsafe fn des_reduction_inline_avx_2(h: &mut [__m256i; 64], i: u64) {
236    unsafe { bitsliced_add_single_inline_avx_2(h, i) };
237    bitsliced_modulo_power_of_two_inline_avx_2(h, 56).unwrap();
238}
239
240static USE_GFNI: OnceLock<bool> = OnceLock::new();
241
242//transpose 64x64 bit matrix
243//use gfni if the cpu supports it, fallback to scalar if it doesn't
244pub fn transpose_64x64(input: &[u64; 64]) -> [u64; 64] {
245    if *USE_GFNI.get_or_init(|| {
246        #[cfg(target_arch = "x86_64")]
247        {
248            std::is_x86_feature_detected!("gfni")
249                && std::is_x86_feature_detected!("avx512f")
250                && std::is_x86_feature_detected!("avx512bw")
251                && std::is_x86_feature_detected!("avx512vbmi")
252        }
253        #[cfg(not(target_arch = "x86_64"))]
254        {
255            false
256        }
257    }) {
258        unsafe { crate::transpose::transpose_gfni(input) }
259    } else {
260        transpose_scalar(input)
261    }
262}
263
#[cfg(test)]
mod tests {
    use std::arch::x86_64::{_mm256_setzero_si256, _mm256_storeu_si256, _mm512_storeu_si512};

    use super::*;

    /// Whether the CPU supports every feature the AVX-512 kernels are compiled with.
    ///
    /// Running those kernels on a CPU without the features is undefined behaviour (in practice a
    /// SIGILL that kills the whole test binary), so the SIMD tests skip instead.
    fn has_avx512() -> bool {
        std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512vl")
            && std::is_x86_feature_detected!("avx512bw")
    }

    /// Whether the CPU supports the AVX2 kernels.
    fn has_avx2() -> bool {
        std::is_x86_feature_detected!("avx2")
    }

    /// Spills all eight 64-bit lanes of `v`. Only call after `has_avx512()` returned `true`.
    fn lanes_512(v: &__m512i) -> [u64; 8] {
        let mut out = [0u64; 8];
        // SAFETY: `out` provides 64 writable bytes for one unaligned 512-bit store, and callers
        // have verified AVX-512F support.
        unsafe { _mm512_storeu_si512(out.as_mut_ptr() as *mut _, *v) };
        out
    }

    /// Spills all four 64-bit lanes of `v`. Only call after `has_avx2()` returned `true`.
    fn lanes_256(v: &__m256i) -> [u64; 4] {
        let mut out = [0u64; 4];
        // SAFETY: `out` provides 32 writable bytes for one unaligned 256-bit store, and callers
        // have verified AVX2 (which implies AVX) support.
        unsafe { _mm256_storeu_si256(out.as_mut_ptr() as *mut _, *v) };
        out
    }

    #[test]
    fn test_add_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let mut b = [ZERO; 64];
        b[63] = ALL_ONES;
        let sum = bitsliced_add(&a, &b);
        assert_eq!(sum[63], ZERO);
        assert_eq!(sum[62], ALL_ONES);
        for i in 0..62 {
            assert_eq!(sum[i], ZERO);
        }
    }

    #[test]
    fn test_add_single_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let sum = bitsliced_add_single(&a, 1);
        assert_eq!(sum[63], ZERO);
        assert_eq!(sum[62], ALL_ONES);
        for i in 0..62 {
            assert_eq!(sum[i], ZERO);
        }
    }

    #[test]
    fn test_add_inline_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let mut b = [ZERO; 64];
        b[63] = ALL_ONES;
        bitsliced_add_inline(&mut a, &b);
        assert_eq!(a[63], ZERO);
        assert_eq!(a[62], ALL_ONES);
        for i in 0..62 {
            assert_eq!(a[i], ZERO);
        }
    }

    #[test]
    fn test_add_single_inline_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        bitsliced_add_single_inline(&mut a, 1);
        assert_eq!(a[63], ZERO);
        assert_eq!(a[62], ALL_ONES);
        for i in 0..62 {
            assert_eq!(a[i], ZERO);
        }
    }

    #[test]
    fn test_add_single_inline_wraps_on_overflow() {
        // u64::MAX + 1 overflows. The final carry is discarded, so every column wraps to zero.
        let mut a = [ALL_ONES; 64];
        bitsliced_add_single_inline(&mut a, 1);
        assert!(a.iter().all(|&row| row == ZERO));
    }

    #[test]
    fn test_add_single_inline_avx_512_works() {
        if !has_avx512() {
            eprintln!("skipping test_add_single_inline_avx_512_works: CPU lacks AVX-512F/VL/BW");
            return;
        }
        let mut a = [unsafe { _mm512_setzero_si512() }; 64];
        a[63] = M512_ONES;
        // SAFETY: AVX-512F/VL/BW support was verified above.
        unsafe { bitsliced_add_single_inline_avx_512(&mut a, 1) };
        // 1 + 1 = 2 in every lane: row 63 (bit 0) clears and row 62 (bit 1) is set.
        for (i, row) in a.iter().enumerate() {
            let expected = if i == 62 { u64::MAX } else { 0 };
            assert_eq!(lanes_512(row), [expected; 8], "row {i}");
        }
    }

    #[test]
    fn test_add_single_inline_avx_512_carries_past_addend_bits() {
        if !has_avx512() {
            eprintln!("skipping: CPU lacks AVX-512F/VL/BW");
            return;
        }
        // 3 + 1 = 4: the carry keeps rippling after the addend's last set bit, so the kernel's
        // early exit must not stop the loop prematurely.
        let mut a = [M512_ZERO; 64];
        a[62] = M512_ONES;
        a[63] = M512_ONES;
        // SAFETY: AVX-512F/VL/BW support was verified above.
        unsafe { bitsliced_add_single_inline_avx_512(&mut a, 1) };
        for (i, row) in a.iter().enumerate() {
            let expected = if i == 61 { u64::MAX } else { 0 };
            assert_eq!(lanes_512(row), [expected; 8], "row {i}");
        }
    }

    #[test]
    fn test_add_single_inline_avx_2_works() {
        if !has_avx2() {
            eprintln!("skipping test_add_single_inline_avx_2_works: CPU lacks AVX2");
            return;
        }
        let mut a = [unsafe { _mm256_setzero_si256() }; 64];
        a[63] = M2_ONES;
        // SAFETY: AVX2 support was verified above.
        unsafe { bitsliced_add_single_inline_avx_2(&mut a, 1) };
        // 1 + 1 = 2 in every lane: row 63 (bit 0) clears and row 62 (bit 1) is set.
        for (i, row) in a.iter().enumerate() {
            let expected = if i == 62 { u64::MAX } else { 0 };
            assert_eq!(lanes_256(row), [expected; 4], "row {i}");
        }
    }

    #[test]
    fn test_add_single_inline_avx_2_carries_past_addend_bits() {
        if !has_avx2() {
            eprintln!("skipping: CPU lacks AVX2");
            return;
        }
        // 3 + 1 = 4: exercises carry propagation beyond the addend's highest set bit.
        let mut a = [M2_ZERO; 64];
        a[62] = M2_ONES;
        a[63] = M2_ONES;
        // SAFETY: AVX2 support was verified above.
        unsafe { bitsliced_add_single_inline_avx_2(&mut a, 1) };
        for (i, row) in a.iter().enumerate() {
            let expected = if i == 61 { u64::MAX } else { 0 };
            assert_eq!(lanes_256(row), [expected; 4], "row {i}");
        }
    }

    #[test]
    fn test_modulo_works() {
        let a = [ALL_ONES; 64];
        let res = bitsliced_modulo_power_of_two(&a, 56).unwrap();
        for i in 0..8 {
            assert_eq!(res[i], ZERO);
        }
        for i in 8..64 {
            assert_eq!(res[i], ALL_ONES);
        }
    }

    #[test]
    fn test_modulo_inline_works() {
        let mut a = [ALL_ONES; 64];
        bitsliced_modulo_power_of_two_inline(&mut a, 56).unwrap();
        for i in 0..8 {
            assert_eq!(a[i], ZERO);
        }
        for i in 8..64 {
            assert_eq!(a[i], ALL_ONES);
        }
    }

    #[test]
    fn test_modulo_rejects_k_above_64() {
        let mut a = [ALL_ONES; 64];
        assert!(bitsliced_modulo_power_of_two(&a, 65).is_err());
        assert!(bitsliced_modulo_power_of_two_inline(&mut a, 65).is_err());
        // A rejected call must leave the input untouched.
        assert!(a.iter().all(|&row| row == ALL_ONES));
    }

    #[test]
    fn test_modulo_inline_avx_512_works() {
        if !has_avx512() {
            eprintln!("skipping test_modulo_inline_avx_512_works: CPU lacks AVX-512F/VL/BW");
            return;
        }
        let mut a = [M512_ONES; 64];
        // SAFETY: AVX-512F/VL/BW support was verified above.
        unsafe { bitsliced_modulo_power_of_two_inline_avx_512(&mut a, 56).unwrap() };
        for (i, row) in a.iter().enumerate() {
            let expected = if i < 8 { 0 } else { u64::MAX };
            assert_eq!(lanes_512(row), [expected; 8], "row {i}");
        }
    }

    #[test]
    fn test_modulo_inline_avx_2_works() {
        if !has_avx2() {
            eprintln!("skipping test_modulo_inline_avx_2_works: CPU lacks AVX2");
            return;
        }
        let mut a = [M2_ONES; 64];
        // SAFETY: AVX2 support was verified above.
        unsafe { bitsliced_modulo_power_of_two_inline_avx_2(&mut a, 56).unwrap() };
        for (i, row) in a.iter().enumerate() {
            let expected = if i < 8 { 0 } else { u64::MAX };
            assert_eq!(lanes_256(row), [expected; 4], "row {i}");
        }
    }
}