1use std::{
2 arch::x86_64::{
3 __m512i, _mm512_and_si512, _mm512_or_si512, _mm512_set1_epi64, _mm512_setzero_si512,
4 _mm512_xor_si512,
5 },
6 io::{Error, ErrorKind},
7 sync::OnceLock,
8};
9
10use wide::u64x8;
11
12use crate::transpose::transpose_scalar;
13
14pub mod benchmark;
15pub mod transpose;
16
17pub const ALL_ONES: u64x8 = u64x8::splat(0xFFFFFFFFFFFFFFFF);
18pub const ZERO: u64x8 = u64x8::ZERO;
19
20pub fn splat(n: u64) -> u64x8 {
21 u64x8::splat(n)
22}
23
24pub fn bitsliced_add(a: &[u64x8; 64], b: &[u64x8; 64]) -> [u64x8; 64] {
27 let mut carry = u64x8::ZERO;
28 let mut sum = [u64x8::ZERO; 64];
29 for i in (0..64).rev() {
30 let res = calc_sum_carry(a[i], b[i], carry);
31 sum[i] = res.0;
32 carry = res.1;
34 }
35 sum
36}
37
38pub fn bitsliced_add_single(a: &[u64x8; 64], b: u64) -> [u64x8; 64] {
39 let mut carry = u64x8::ZERO;
40 let mut sum = [u64x8::ZERO; 64];
41 for i in (0..64).rev() {
42 let shift_right = 63 - i;
43 let current_bit = (b >> shift_right) & 1;
44 let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
45 let res = calc_sum_carry(a[i], b_i, carry);
46 sum[i] = res.0;
47 carry = res.1;
49 }
50 sum
51}
52
53pub fn bitsliced_add_inline(a: &mut [u64x8; 64], b: &[u64x8; 64]) {
54 let mut carry = u64x8::ZERO;
55 for i in (0..64).rev() {
56 let res = calc_sum_carry(a[i], b[i], carry);
57 a[i] = res.0;
58 carry = res.1;
60 }
61}
62
63pub fn bitsliced_add_single_inline(a: &mut [u64x8; 64], b: u64) {
64 let mut carry = u64x8::ZERO;
65 for i in (0..64).rev() {
66 let shift_right = 63 - i;
67 let current_bit = (b >> shift_right) & 1;
68 let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
69 let res = calc_sum_carry(a[i], b_i, carry);
70 a[i] = res.0;
71 carry = res.1;
73 }
74}
75
76fn calc_sum_carry(a: u64x8, b: u64x8, carry: u64x8) -> (u64x8, u64x8) {
77 let sum = a ^ b ^ carry;
78 let next_carry = (a & b) | (carry & (a ^ b));
79 (sum, next_carry)
80}
81
/// 512-bit vector with every bit set.
// SAFETY: `[u64; 8]` and `__m512i` are both 64 bytes and every bit pattern is a valid `__m512i`.
const M512_ONES: __m512i = unsafe { std::mem::transmute::<[u64; 8], __m512i>([u64::MAX; 8]) };

/// 512-bit vector with every bit cleared.
// SAFETY: same layout argument as for `M512_ONES`; the all-zero pattern is a valid `__m512i`.
const M512_ZERO: __m512i = unsafe { std::mem::transmute::<[u64; 8], __m512i>([0; 8]) };
85
86pub unsafe fn bitsliced_add_single_inline_avx(a: &mut [__m512i; 64], b: u64) {
87 let mut carry = _mm512_setzero_si512();
88 for i in (0..64).rev() {
89 let shift_right = 63 - i;
90 let current_bit = (b >> shift_right) & 1;
91 let b_i = if current_bit == 1 {
92 M512_ONES
93 } else {
94 M512_ZERO
95 };
96 let res = calc_sum_carry_avx(a[i], b_i, carry);
97 a[i] = res.0;
98 carry = res.1;
100 }
101}
102
103unsafe fn calc_sum_carry_avx(a: __m512i, b: __m512i, carry: __m512i) -> (__m512i, __m512i) {
104 let sum = _mm512_xor_si512(_mm512_xor_si512(a, b), carry);
105 let axb = _mm512_xor_si512(a, b);
106 let a_and_b = _mm512_and_si512(a, b);
107 let carry_term = _mm512_and_si512(carry, axb);
108 let next_carry = _mm512_or_si512(a_and_b, carry_term);
109 (sum, next_carry)
110}
111
112pub fn bitsliced_modulo_power_of_two(a: &[u64x8; 64], k: usize) -> Result<[u64x8; 64], Error> {
116 if k > 64 {
117 return Err(Error::new(
118 ErrorKind::InvalidData,
119 "k must be <= 64 for bitsliced modulo",
120 ));
121 }
122 let mut out = [u64x8::splat(0); 64];
123 let start: usize = 64 - k;
124 out[start..].copy_from_slice(&a[start..]);
125
126 Ok(out)
127}
128
129pub fn bitsliced_modulo_power_of_two_inline(a: &mut [u64x8; 64], k: usize) -> Result<(), Error> {
130 if k > 64 {
131 return Err(Error::new(
132 ErrorKind::InvalidData,
133 "k must be <= 64 for bitsliced modulo",
134 ));
135 }
136 let end: usize = 64 - k;
137 for i in 0..end {
138 a[i] = u64x8::splat(0);
139 }
140
141 Ok(())
142}
143
144pub fn bitsliced_modulo_power_of_two_inline_avx(
145 a: &mut [__m512i; 64],
146 k: usize,
147) -> Result<(), Error> {
148 if k > 64 {
149 return Err(Error::new(
150 ErrorKind::InvalidData,
151 "k must be <= 64 for bitsliced modulo",
152 ));
153 }
154 let end: usize = 64 - k;
155 for i in 0..end {
156 a[i] = M512_ZERO
157 }
158
159 Ok(())
160}
161
162pub fn des_reduction(h: &[u64x8; 64], i: u64) -> [u64x8; 64] {
165 let mut sum = bitsliced_add_single(h, i);
166 bitsliced_modulo_power_of_two_inline(&mut sum, 56).unwrap();
167 sum
168}
169
170pub fn des_reduction_inline(h: &mut [u64x8; 64], i: u64) {
171 bitsliced_add_single_inline(h, i);
172 bitsliced_modulo_power_of_two_inline(h, 56).unwrap();
173}
174
175pub unsafe fn des_reduction_inline_avx(h: &mut [__m512i; 64], i: u64) {
176 unsafe { bitsliced_add_single_inline_avx(h, i) };
177 bitsliced_modulo_power_of_two_inline_avx(h, 56).unwrap();
178}
179
180static USE_GFNI: OnceLock<bool> = OnceLock::new();
181
182pub fn transpose_64x64(input: &[u64; 64]) -> [u64; 64] {
185 if *USE_GFNI.get_or_init(|| {
186 #[cfg(target_arch = "x86_64")]
187 {
188 std::is_x86_feature_detected!("gfni")
189 && std::is_x86_feature_detected!("avx512f")
190 && std::is_x86_feature_detected!("avx512bw")
191 && std::is_x86_feature_detected!("avx512vbmi")
192 }
193 #[cfg(not(target_arch = "x86_64"))]
194 {
195 false
196 }
197 }) {
198 unsafe { crate::transpose::transpose_gfni(input) }
199 } else {
200 transpose_scalar(input)
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Bitsliced operand whose only set slice is index 63, i.e. the value 1 at every position.
    fn one() -> [u64x8; 64] {
        let mut slices = [ZERO; 64];
        slices[63] = ALL_ONES;
        slices
    }

    /// Asserts that `value` encodes 2 everywhere: slice 63 clear, slice 62 set, the rest clear.
    fn assert_is_two(value: &[u64x8; 64]) {
        assert_eq!(value[63], ZERO);
        assert_eq!(value[62], ALL_ONES);
        for slice in &value[..62] {
            assert_eq!(*slice, ZERO);
        }
    }

    /// Asserts that slices `0..8` are clear and slices `8..64` are fully set.
    fn assert_low_56_slices_kept(value: &[u64x8; 64]) {
        let (cleared, kept) = value.split_at(8);
        for slice in cleared {
            assert_eq!(*slice, ZERO);
        }
        for slice in kept {
            assert_eq!(*slice, ALL_ONES);
        }
    }

    #[test]
    fn test_add_works() {
        let sum = bitsliced_add(&one(), &one());
        assert_is_two(&sum);
    }

    #[test]
    fn test_add_single_works() {
        let sum = bitsliced_add_single(&one(), 1);
        assert_is_two(&sum);
    }

    #[test]
    fn test_add_inline_works() {
        let mut acc = one();
        bitsliced_add_inline(&mut acc, &one());
        assert_is_two(&acc);
    }

    #[test]
    fn test_add_single_inline_works() {
        let mut acc = one();
        bitsliced_add_single_inline(&mut acc, 1);
        assert_is_two(&acc);
    }

    #[test]
    fn test_modulo_works() {
        let input = [ALL_ONES; 64];
        let reduced = bitsliced_modulo_power_of_two(&input, 56).unwrap();
        assert_low_56_slices_kept(&reduced);
    }

    #[test]
    fn test_modulo_inline_works() {
        let mut value = [ALL_ONES; 64];
        bitsliced_modulo_power_of_two_inline(&mut value, 56).unwrap();
        assert_low_56_slices_kept(&value);
    }
}