Skip to main content

bitsliced_op/
lib.rs

1use std::{
2    arch::x86_64::{
3        __m512i, _mm512_and_si512, _mm512_or_si512, _mm512_set1_epi64, _mm512_setzero_si512,
4        _mm512_xor_si512,
5    },
6    io::{Error, ErrorKind},
7    sync::OnceLock,
8};
9
10use wide::u64x8;
11
12use crate::transpose::transpose_scalar;
13
14pub mod benchmark;
15pub mod transpose;
16
/// Vector in which every bit of each of the eight 64-bit lanes is set.
pub const ALL_ONES: u64x8 = u64x8::splat(0xFFFFFFFFFFFFFFFF);
/// Vector in which every bit of each of the eight 64-bit lanes is cleared.
pub const ZERO: u64x8 = u64x8::ZERO;
19
/// Convenience wrapper that forwards `n` to `u64x8::splat`, producing a vector
/// whose lanes all hold `n`.
pub fn splat(n: u64) -> u64x8 {
    u64x8::splat(n)
}
23
24//expects the input to be in bitsliced form e.g integers are columns, not rows
25//last row is LSB
26pub fn bitsliced_add(a: &[u64x8; 64], b: &[u64x8; 64]) -> [u64x8; 64] {
27    let mut carry = u64x8::ZERO;
28    let mut sum = [u64x8::ZERO; 64];
29    for i in (0..64).rev() {
30        let res = calc_sum_carry(a[i], b[i], carry);
31        sum[i] = res.0;
32        //only set carry if we haven't reached the end yet, we currently ignore overflows
33        carry = res.1;
34    }
35    sum
36}
37
38pub fn bitsliced_add_single(a: &[u64x8; 64], b: u64) -> [u64x8; 64] {
39    let mut carry = u64x8::ZERO;
40    let mut sum = [u64x8::ZERO; 64];
41    for i in (0..64).rev() {
42        let shift_right = 63 - i;
43        let current_bit = (b >> shift_right) & 1;
44        let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
45        let res = calc_sum_carry(a[i], b_i, carry);
46        sum[i] = res.0;
47        //only set carry if we haven't reached the end yet, we currently ignore overflows
48        carry = res.1;
49    }
50    sum
51}
52
53pub fn bitsliced_add_inline(a: &mut [u64x8; 64], b: &[u64x8; 64]) {
54    let mut carry = u64x8::ZERO;
55    for i in (0..64).rev() {
56        let res = calc_sum_carry(a[i], b[i], carry);
57        a[i] = res.0;
58        //only set carry if we haven't reached the end yet, we currently ignore overflows
59        carry = res.1;
60    }
61}
62
63pub fn bitsliced_add_single_inline(a: &mut [u64x8; 64], b: u64) {
64    let mut carry = u64x8::ZERO;
65    for i in (0..64).rev() {
66        let shift_right = 63 - i;
67        let current_bit = (b >> shift_right) & 1;
68        let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
69        let res = calc_sum_carry(a[i], b_i, carry);
70        a[i] = res.0;
71        //only set carry if we haven't reached the end yet, we currently ignore overflows
72        carry = res.1;
73    }
74}
75
76fn calc_sum_carry(a: u64x8, b: u64x8, carry: u64x8) -> (u64x8, u64x8) {
77    let sum = a ^ b ^ carry;
78    let next_carry = (a & b) | (carry & (a ^ b));
79    (sum, next_carry)
80}
81
82const M512_ONES: __m512i = unsafe { std::mem::transmute([!0u64; 8]) };
83
84const M512_ZERO: __m512i = unsafe { std::mem::transmute([0u64; 8]) };
85
86pub unsafe fn bitsliced_add_single_inline_avx(a: &mut [__m512i; 64], b: u64) {
87    let mut carry = _mm512_setzero_si512();
88    for i in (0..64).rev() {
89        let shift_right = 63 - i;
90        let current_bit = (b >> shift_right) & 1;
91        let b_i = if current_bit == 1 {
92            M512_ONES
93        } else {
94            M512_ZERO
95        };
96        let res = calc_sum_carry_avx(a[i], b_i, carry);
97        a[i] = res.0;
98        //only set carry if we haven't reached the end yet, we currently ignore overflows
99        carry = res.1;
100    }
101}
102
103unsafe fn calc_sum_carry_avx(a: __m512i, b: __m512i, carry: __m512i) -> (__m512i, __m512i) {
104    let sum = _mm512_xor_si512(_mm512_xor_si512(a, b), carry);
105    let axb = _mm512_xor_si512(a, b);
106    let a_and_b = _mm512_and_si512(a, b);
107    let carry_term = _mm512_and_si512(carry, axb);
108    let next_carry = _mm512_or_si512(a_and_b, carry_term);
109    (sum, next_carry)
110}
111
112//this function only works when calculating the module with a number of the power of two
113//currently only supports a single modulo operation for all integers
114//example: if you want to calculate the modulo with 2^56, pass 56 to k
115pub fn bitsliced_modulo_power_of_two(a: &[u64x8; 64], k: usize) -> Result<[u64x8; 64], Error> {
116    if k > 64 {
117        return Err(Error::new(
118            ErrorKind::InvalidData,
119            "k must be <= 64 for bitsliced modulo",
120        ));
121    }
122    let mut out = [u64x8::splat(0); 64];
123    let start: usize = 64 - k;
124    out[start..].copy_from_slice(&a[start..]);
125
126    Ok(out)
127}
128
129pub fn bitsliced_modulo_power_of_two_inline(a: &mut [u64x8; 64], k: usize) -> Result<(), Error> {
130    if k > 64 {
131        return Err(Error::new(
132            ErrorKind::InvalidData,
133            "k must be <= 64 for bitsliced modulo",
134        ));
135    }
136    let end: usize = 64 - k;
137    for i in 0..end {
138        a[i] = u64x8::splat(0);
139    }
140
141    Ok(())
142}
143
144pub fn bitsliced_modulo_power_of_two_inline_avx(
145    a: &mut [__m512i; 64],
146    k: usize,
147) -> Result<(), Error> {
148    if k > 64 {
149        return Err(Error::new(
150            ErrorKind::InvalidData,
151            "k must be <= 64 for bitsliced modulo",
152        ));
153    }
154    let end: usize = 64 - k;
155    for i in 0..end {
156        a[i] = M512_ZERO
157    }
158
159    Ok(())
160}
161
162//reduction function: (H+I)%MAX_SIZE
163//H=Hash,I=Index in chain,MAX_SIZE=Max size of output in power of 2
164pub fn des_reduction(h: &[u64x8; 64], i: u64) -> [u64x8; 64] {
165    let mut sum = bitsliced_add_single(h, i);
166    bitsliced_modulo_power_of_two_inline(&mut sum, 56).unwrap();
167    sum
168}
169
170pub fn des_reduction_inline(h: &mut [u64x8; 64], i: u64) {
171    bitsliced_add_single_inline(h, i);
172    bitsliced_modulo_power_of_two_inline(h, 56).unwrap();
173}
174
175pub unsafe fn des_reduction_inline_avx(h: &mut [__m512i; 64], i: u64) {
176    unsafe { bitsliced_add_single_inline_avx(h, i) };
177    bitsliced_modulo_power_of_two_inline_avx(h, 56).unwrap();
178}
179
180static USE_GFNI: OnceLock<bool> = OnceLock::new();
181
182//transpose 64x64 bit matrix
183//use gfni if the cpu supports it, fallback to scalar if it doesn't
184pub fn transpose_64x64(input: &[u64; 64]) -> [u64; 64] {
185    if *USE_GFNI.get_or_init(|| {
186        #[cfg(target_arch = "x86_64")]
187        {
188            std::is_x86_feature_detected!("gfni")
189                && std::is_x86_feature_detected!("avx512f")
190                && std::is_x86_feature_detected!("avx512bw")
191                && std::is_x86_feature_detected!("avx512vbmi")
192        }
193        #[cfg(not(target_arch = "x86_64"))]
194        {
195            false
196        }
197    }) {
198        unsafe { crate::transpose::transpose_gfni(input) }
199    } else {
200        transpose_scalar(input)
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Row index of the least significant bit.
    const LSB_ROW: usize = 63;
    /// Exponent used by the modulo tests.
    const MODULO_BITS: usize = 56;

    /// Bitsliced operand with only the LSB row set, i.e. every integer is 1.
    fn every_integer_one() -> [u64x8; 64] {
        let mut rows = [ZERO; 64];
        rows[LSB_ROW] = ALL_ONES;
        rows
    }

    /// Checks that only row 62 is set, i.e. every integer equals 2.
    fn assert_every_integer_two(rows: &[u64x8; 64]) {
        for (idx, row) in rows.iter().enumerate() {
            let expected = if idx == LSB_ROW - 1 { ALL_ONES } else { ZERO };
            assert_eq!(*row, expected, "unexpected value in row {idx}");
        }
    }

    /// Checks that the first `64 - MODULO_BITS` rows are cleared and all
    /// remaining rows are still fully set.
    fn assert_high_rows_cleared(rows: &[u64x8; 64]) {
        let (cleared, kept) = rows.split_at(64 - MODULO_BITS);
        for row in cleared {
            assert_eq!(*row, ZERO);
        }
        for row in kept {
            assert_eq!(*row, ALL_ONES);
        }
    }

    #[test]
    fn test_add_works() {
        let lhs = every_integer_one();
        let rhs = every_integer_one();
        assert_every_integer_two(&bitsliced_add(&lhs, &rhs));
    }

    #[test]
    fn test_add_single_works() {
        let lhs = every_integer_one();
        assert_every_integer_two(&bitsliced_add_single(&lhs, 1));
    }

    #[test]
    fn test_add_inline_works() {
        let mut acc = every_integer_one();
        let rhs = every_integer_one();
        bitsliced_add_inline(&mut acc, &rhs);
        assert_every_integer_two(&acc);
    }

    #[test]
    fn test_add_single_inline_works() {
        let mut acc = every_integer_one();
        bitsliced_add_single_inline(&mut acc, 1);
        assert_every_integer_two(&acc);
    }

    #[test]
    fn test_modulo_works() {
        let input = [ALL_ONES; 64];
        let reduced = bitsliced_modulo_power_of_two(&input, MODULO_BITS).unwrap();
        assert_high_rows_cleared(&reduced);
    }

    #[test]
    fn test_modulo_inline_works() {
        let mut rows = [ALL_ONES; 64];
        bitsliced_modulo_power_of_two_inline(&mut rows, MODULO_BITS).unwrap();
        assert_high_rows_cleared(&rows);
    }
}