Skip to main content

bitsliced_op/
lib.rs

1use std::{
2    arch::x86_64::{
3        __m512i, _mm512_and_si512, _mm512_or_si512, _mm512_set1_epi64, _mm512_setzero_si512,
4        _mm512_ternarylogic_epi32, _mm512_test_epi64_mask, _mm512_xor_si512,
5    },
6    io::{Error, ErrorKind},
7    sync::OnceLock,
8};
9
10use wide::u64x8;
11
12use crate::transpose::transpose_scalar;
13
14pub mod benchmark;
15pub mod transpose;
16
17pub const ALL_ONES: u64x8 = u64x8::splat(0xFFFFFFFFFFFFFFFF);
18pub const ZERO: u64x8 = u64x8::ZERO;
19
20pub fn splat(n: u64) -> u64x8 {
21    u64x8::splat(n)
22}
23
24//expects the input to be in bitsliced form e.g integers are columns, not rows
25//last row is LSB
26pub fn bitsliced_add(a: &[u64x8; 64], b: &[u64x8; 64]) -> [u64x8; 64] {
27    let mut carry = u64x8::ZERO;
28    let mut sum = [u64x8::ZERO; 64];
29    for i in (0..64).rev() {
30        let res = calc_sum_carry(a[i], b[i], carry);
31        sum[i] = res.0;
32        //only set carry if we haven't reached the end yet, we currently ignore overflows
33        carry = res.1;
34    }
35    sum
36}
37
38pub fn bitsliced_add_single(a: &[u64x8; 64], b: u64) -> [u64x8; 64] {
39    let mut carry = u64x8::ZERO;
40    let mut sum = [u64x8::ZERO; 64];
41    for i in (0..64).rev() {
42        let shift_right = 63 - i;
43        let current_bit = (b >> shift_right) & 1;
44        let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
45        let res = calc_sum_carry(a[i], b_i, carry);
46        sum[i] = res.0;
47        //only set carry if we haven't reached the end yet, we currently ignore overflows
48        carry = res.1;
49    }
50    sum
51}
52
53pub fn bitsliced_add_inline(a: &mut [u64x8; 64], b: &[u64x8; 64]) {
54    let mut carry = u64x8::ZERO;
55    for i in (0..64).rev() {
56        let res = calc_sum_carry(a[i], b[i], carry);
57        a[i] = res.0;
58        //only set carry if we haven't reached the end yet, we currently ignore overflows
59        carry = res.1;
60    }
61}
62
63pub fn bitsliced_add_single_inline(a: &mut [u64x8; 64], b: u64) {
64    let mut carry = u64x8::ZERO;
65    for i in (0..64).rev() {
66        let shift_right = 63 - i;
67        let current_bit = (b >> shift_right) & 1;
68        let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
69        let res = calc_sum_carry(a[i], b_i, carry);
70        a[i] = res.0;
71        //only set carry if we haven't reached the end yet, we currently ignore overflows
72        carry = res.1;
73    }
74}
75
76fn calc_sum_carry(a: u64x8, b: u64x8, carry: u64x8) -> (u64x8, u64x8) {
77    let sum = a ^ b ^ carry;
78    let next_carry = (a & b) | (carry & (a ^ b));
79    (sum, next_carry)
80}
81
/// 512-bit register with every bit set (AVX-512 counterpart of `ALL_ONES`).
// SAFETY: `[u64; 8]` and `__m512i` are both 64 bytes and any bit pattern is a valid `__m512i`.
const M512_ONES: __m512i = unsafe { std::mem::transmute::<[u64; 8], __m512i>([u64::MAX; 8]) };

/// 512-bit register with every bit cleared (AVX-512 counterpart of `ZERO`).
// SAFETY: same size, and the all-zero bit pattern is a valid `__m512i`.
const M512_ZERO: __m512i = unsafe { std::mem::transmute::<[u64; 8], __m512i>([0u64; 8]) };
85
86pub unsafe fn bitsliced_add_single_inline_avx(a: &mut [__m512i; 64], b: u64) {
87    let mut carry = M512_ZERO;
88    let max_bit_pos = 64 - (b.leading_zeros() as usize);
89    for i in (0..64).rev() {
90        let bit_index = 63 - i;
91        //break early to save cpu cycles
92        if bit_index >= max_bit_pos {
93            if _mm512_test_epi64_mask(carry, carry) == 0 {
94                break;
95            }
96        }
97
98        let current_bit = if ((b >> bit_index) & 1) == 1 {
99            M512_ONES
100        } else {
101            M512_ZERO
102        };
103
104        let a_orig = a[i];
105
106        a[i] = _mm512_ternarylogic_epi32(a_orig, current_bit, carry, 0x96);
107
108        carry = _mm512_ternarylogic_epi32(a_orig, current_bit, carry, 0xE8);
109    }
110}
111
112//this function only works when calculating the module with a number of the power of two
113//currently only supports a single modulo operation for all integers
114//example: if you want to calculate the modulo with 2^56, pass 56 to k
115pub fn bitsliced_modulo_power_of_two(a: &[u64x8; 64], k: usize) -> Result<[u64x8; 64], Error> {
116    if k > 64 {
117        return Err(Error::new(
118            ErrorKind::InvalidData,
119            "k must be <= 64 for bitsliced modulo",
120        ));
121    }
122    let mut out = [u64x8::splat(0); 64];
123    let start: usize = 64 - k;
124    out[start..].copy_from_slice(&a[start..]);
125
126    Ok(out)
127}
128
129pub fn bitsliced_modulo_power_of_two_inline(a: &mut [u64x8; 64], k: usize) -> Result<(), Error> {
130    if k > 64 {
131        return Err(Error::new(
132            ErrorKind::InvalidData,
133            "k must be <= 64 for bitsliced modulo",
134        ));
135    }
136    let end: usize = 64 - k;
137    for i in 0..end {
138        a[i] = u64x8::splat(0);
139    }
140
141    Ok(())
142}
143
144pub fn bitsliced_modulo_power_of_two_inline_avx(
145    a: &mut [__m512i; 64],
146    k: usize,
147) -> Result<(), Error> {
148    if k > 64 {
149        return Err(Error::new(
150            ErrorKind::InvalidData,
151            "k must be <= 64 for bitsliced modulo",
152        ));
153    }
154    let end: usize = 64 - k;
155    for i in 0..end {
156        a[i] = M512_ZERO
157    }
158
159    Ok(())
160}
161
162//reduction function: (H+I)%MAX_SIZE
163//H=Hash,I=Index in chain,MAX_SIZE=Max size of output in power of 2
164pub fn des_reduction(h: &[u64x8; 64], i: u64) -> [u64x8; 64] {
165    let mut sum = bitsliced_add_single(h, i);
166    bitsliced_modulo_power_of_two_inline(&mut sum, 56).unwrap();
167    sum
168}
169
170pub fn des_reduction_inline(h: &mut [u64x8; 64], i: u64) {
171    bitsliced_add_single_inline(h, i);
172    bitsliced_modulo_power_of_two_inline(h, 56).unwrap();
173}
174
175pub unsafe fn des_reduction_inline_avx(h: &mut [__m512i; 64], i: u64) {
176    unsafe { bitsliced_add_single_inline_avx(h, i) };
177    bitsliced_modulo_power_of_two_inline_avx(h, 56).unwrap();
178}
179
180static USE_GFNI: OnceLock<bool> = OnceLock::new();
181
182//transpose 64x64 bit matrix
183//use gfni if the cpu supports it, fallback to scalar if it doesn't
184pub fn transpose_64x64(input: &[u64; 64]) -> [u64; 64] {
185    if *USE_GFNI.get_or_init(|| {
186        #[cfg(target_arch = "x86_64")]
187        {
188            std::is_x86_feature_detected!("gfni")
189                && std::is_x86_feature_detected!("avx512f")
190                && std::is_x86_feature_detected!("avx512bw")
191                && std::is_x86_feature_detected!("avx512vbmi")
192        }
193        #[cfg(not(target_arch = "x86_64"))]
194        {
195            false
196        }
197    }) {
198        unsafe { crate::transpose::transpose_gfni(input) }
199    } else {
200        transpose_scalar(input)
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Reinterprets 64 rows of eight `u64` lanes as AVX-512 registers.
    ///
    /// Uses `transmute` rather than AVX-512 intrinsics, so building inputs never
    /// executes an AVX-512 instruction before the CPU feature check.
    fn to_m512_rows(rows: [[u64; 8]; 64]) -> [__m512i; 64] {
        // SAFETY: both types are 4096 bytes and any bit pattern is valid for each.
        unsafe { std::mem::transmute(rows) }
    }

    /// Exposes every lane of every register; inverse of `to_m512_rows`.
    fn from_m512_rows(rows: &[__m512i; 64]) -> [[u64; 8]; 64] {
        // SAFETY: same size, and `[u64; 8]` accepts any bit pattern.
        unsafe { std::mem::transmute(*rows) }
    }

    /// Packs 512 integers into bitsliced rows: integer `n` lives in lane
    /// `n / 64`, bit `n % 64`, and row `r` holds bit `63 - r` of each integer.
    fn bitslice(values: &[u64; 512]) -> [[u64; 8]; 64] {
        let mut rows = [[0u64; 8]; 64];
        for (n, &v) in values.iter().enumerate() {
            for (r, row) in rows.iter_mut().enumerate() {
                row[n / 64] |= ((v >> (63 - r)) & 1) << (n % 64);
            }
        }
        rows
    }

    /// Inverse of `bitslice`.
    fn unbitslice(rows: &[[u64; 8]; 64]) -> [u64; 512] {
        let mut values = [0u64; 512];
        for (n, v) in values.iter_mut().enumerate() {
            for (r, row) in rows.iter().enumerate() {
                *v |= ((row[n / 64] >> (n % 64)) & 1) << (63 - r);
            }
        }
        values
    }

    /// AVX-512 code must only run on CPUs that support it; otherwise the test
    /// process dies with an illegal-instruction fault instead of failing cleanly.
    fn avx512f_available() -> bool {
        std::is_x86_feature_detected!("avx512f")
    }

    #[test]
    fn test_add_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let mut b = [ZERO; 64];
        b[63] = ALL_ONES;
        let sum = bitsliced_add(&a, &b);
        assert_eq!(sum[63], ZERO);
        assert_eq!(sum[62], ALL_ONES);
        for row in &sum[..62] {
            assert_eq!(*row, ZERO);
        }
    }

    #[test]
    fn test_add_single_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let sum = bitsliced_add_single(&a, 1);
        assert_eq!(sum[63], ZERO);
        assert_eq!(sum[62], ALL_ONES);
        for row in &sum[..62] {
            assert_eq!(*row, ZERO);
        }
    }

    #[test]
    fn test_add_inline_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let mut b = [ZERO; 64];
        b[63] = ALL_ONES;
        bitsliced_add_inline(&mut a, &b);
        assert_eq!(a[63], ZERO);
        assert_eq!(a[62], ALL_ONES);
        for row in &a[..62] {
            assert_eq!(*row, ZERO);
        }
    }

    #[test]
    fn test_add_single_inline_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        bitsliced_add_single_inline(&mut a, 1);
        assert_eq!(a[63], ZERO);
        assert_eq!(a[62], ALL_ONES);
        for row in &a[..62] {
            assert_eq!(*row, ZERO);
        }
    }

    #[test]
    fn test_add_single_inline_avx_works() {
        if !avx512f_available() {
            eprintln!("skipping: CPU does not support AVX-512F");
            return;
        }
        let mut rows = [[0u64; 8]; 64];
        rows[63] = [u64::MAX; 8];
        let mut a = to_m512_rows(rows);
        // SAFETY: AVX-512F support was verified above.
        unsafe { bitsliced_add_single_inline_avx(&mut a, 1) };
        let out = from_m512_rows(&a);
        // 1 + 1 = 2 in every integer of every lane.
        assert_eq!(out[63], [0u64; 8]);
        assert_eq!(out[62], [u64::MAX; 8]);
        for row in &out[..62] {
            assert_eq!(*row, [0u64; 8]);
        }
    }

    #[test]
    fn test_add_single_inline_avx_carries_past_top_bit_of_b() {
        if !avx512f_available() {
            eprintln!("skipping: CPU does not support AVX-512F");
            return;
        }
        // Every integer holds 3; adding 1 must keep rippling the carry above
        // the highest set bit of `b`, exercising the early-exit condition.
        let mut rows = [[0u64; 8]; 64];
        rows[62] = [u64::MAX; 8];
        rows[63] = [u64::MAX; 8];
        let mut a = to_m512_rows(rows);
        // SAFETY: AVX-512F support was verified above.
        unsafe { bitsliced_add_single_inline_avx(&mut a, 1) };
        let mut expected = [[0u64; 8]; 64];
        expected[61] = [u64::MAX; 8];
        assert_eq!(from_m512_rows(&a), expected);
    }

    #[test]
    fn test_add_single_inline_avx_matches_scalar_wrapping_add() {
        if !avx512f_available() {
            eprintln!("skipping: CPU does not support AVX-512F");
            return;
        }
        let values: [u64; 512] =
            std::array::from_fn(|n| (n as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        for b in [0u64, 1, 0xDEAD_BEEF, u64::MAX] {
            let mut a = to_m512_rows(bitslice(&values));
            // SAFETY: AVX-512F support was verified above.
            unsafe { bitsliced_add_single_inline_avx(&mut a, b) };
            let expected = values.map(|v| v.wrapping_add(b));
            assert_eq!(unbitslice(&from_m512_rows(&a)), expected, "b = {b:#x}");
        }
    }

    #[test]
    fn test_modulo_works() {
        let a = [ALL_ONES; 64];
        let res = bitsliced_modulo_power_of_two(&a, 56).unwrap();
        for row in &res[..8] {
            assert_eq!(*row, ZERO);
        }
        for row in &res[8..] {
            assert_eq!(*row, ALL_ONES);
        }
    }

    #[test]
    fn test_modulo_boundaries() {
        let a = [ALL_ONES; 64];
        assert_eq!(bitsliced_modulo_power_of_two(&a, 64).unwrap(), a);
        assert_eq!(bitsliced_modulo_power_of_two(&a, 0).unwrap(), [ZERO; 64]);
    }

    #[test]
    fn test_modulo_rejects_k_above_64() {
        let a = [ALL_ONES; 64];
        let err = bitsliced_modulo_power_of_two(&a, 65).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);

        let mut b = [ALL_ONES; 64];
        assert!(bitsliced_modulo_power_of_two_inline(&mut b, 65).is_err());
        // The input must be left untouched when the call fails.
        assert_eq!(b, [ALL_ONES; 64]);
    }

    #[test]
    fn test_modulo_inline_works() {
        let mut a = [ALL_ONES; 64];
        bitsliced_modulo_power_of_two_inline(&mut a, 56).unwrap();
        for row in &a[..8] {
            assert_eq!(*row, ZERO);
        }
        for row in &a[8..] {
            assert_eq!(*row, ALL_ONES);
        }
    }

    #[test]
    fn test_modulo_inline_avx_works() {
        // Only overwrites rows with a constant; no AVX-512 intrinsic is called.
        let mut a = to_m512_rows([[u64::MAX; 8]; 64]);
        bitsliced_modulo_power_of_two_inline_avx(&mut a, 56).unwrap();
        let out = from_m512_rows(&a);
        for row in &out[..8] {
            assert_eq!(*row, [0u64; 8]);
        }
        for row in &out[8..] {
            assert_eq!(*row, [u64::MAX; 8]);
        }
        assert!(bitsliced_modulo_power_of_two_inline_avx(&mut a, 65).is_err());
    }

    #[test]
    fn test_des_reduction_adds_then_masks_to_56_bits() {
        // 0 + 5 = 0b101 -> rows 63 and 61 set.
        let r = des_reduction(&[ZERO; 64], 5);
        for (row, v) in r.iter().enumerate() {
            let expected = if row == 63 || row == 61 { ALL_ONES } else { ZERO };
            assert_eq!(*v, expected, "row {row}");
        }
        // u64::MAX + 1 wraps to 0 because the final carry is discarded.
        assert_eq!(des_reduction(&[ALL_ONES; 64], 1), [ZERO; 64]);
        // Bit 60 of `i` lands in row 3, which the mod 2^56 clears.
        assert_eq!(des_reduction(&[ZERO; 64], 1 << 60), [ZERO; 64]);

        let mut h = [ZERO; 64];
        des_reduction_inline(&mut h, 5);
        assert_eq!(h, r);
    }
}