1use std::{
2 io::{Error, ErrorKind},
3 sync::OnceLock,
4};
5
6use wide::u64x8;
7
8use crate::transpose::transpose_scalar;
9
10pub mod benchmark;
11pub mod transpose;
12
13pub const ALL_ONES: u64x8 = u64x8::splat(0xFFFFFFFFFFFFFFFF);
14pub const ZERO: u64x8 = u64x8::ZERO;
15
16pub fn splat(n: u64) -> u64x8 {
17 u64x8::splat(n)
18}
19
20pub fn bitsliced_add(a: &[u64x8; 64], b: &[u64x8; 64]) -> [u64x8; 64] {
23 let mut carry = u64x8::ZERO;
24 let mut sum = [u64x8::ZERO; 64];
25 for i in (0..64).rev() {
26 let res = calc_sum_carry(a[i], b[i], carry);
27 sum[i] = res.0;
28 carry = res.1;
30 }
31 sum
32}
33
34pub fn bitsliced_add_single(a: &[u64x8; 64], b: u64) -> [u64x8; 64] {
35 let mut carry = u64x8::ZERO;
36 let mut sum = [u64x8::ZERO; 64];
37 for i in (0..64).rev() {
38 let shift_right = 63 - i;
39 let current_bit = (b >> shift_right) & 1;
40 let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
41 let res = calc_sum_carry(a[i], b_i, carry);
42 sum[i] = res.0;
43 carry = res.1;
45 }
46 sum
47}
48
49pub fn bitsliced_add_inline(a: &mut [u64x8; 64], b: &[u64x8; 64]) {
50 let mut carry = u64x8::ZERO;
51 for i in (0..64).rev() {
52 let res = calc_sum_carry(a[i], b[i], carry);
53 a[i] = res.0;
54 carry = res.1;
56 }
57}
58
59pub fn bitsliced_add_single_inline(a: &mut [u64x8; 64], b: u64) {
60 let mut carry = u64x8::ZERO;
61 for i in (0..64).rev() {
62 let shift_right = 63 - i;
63 let current_bit = (b >> shift_right) & 1;
64 let b_i = if current_bit == 1 { ALL_ONES } else { ZERO };
65 let res = calc_sum_carry(a[i], b_i, carry);
66 a[i] = res.0;
67 carry = res.1;
69 }
70}
71
72fn calc_sum_carry(a: u64x8, b: u64x8, carry: u64x8) -> (u64x8, u64x8) {
73 let sum = a ^ b ^ carry;
74 let next_carry = (a & b) | (carry & (a ^ b));
75 (sum, next_carry)
76}
77
78pub fn bitsliced_modulo_power_of_two(a: &[u64x8; 64], k: usize) -> Result<[u64x8; 64], Error> {
82 if k > 64 {
83 return Err(Error::new(
84 ErrorKind::InvalidData,
85 "k must be <= 64 for bitsliced modulo",
86 ));
87 }
88 let mut out = [u64x8::splat(0); 64];
89 let start: usize = 64 - k;
90 out[start..].copy_from_slice(&a[start..]);
91
92 Ok(out)
93}
94
95pub fn bitsliced_modulo_power_of_two_inline(a: &mut [u64x8; 64], k: usize) -> Result<(), Error> {
96 if k > 64 {
97 return Err(Error::new(
98 ErrorKind::InvalidData,
99 "k must be <= 64 for bitsliced modulo",
100 ));
101 }
102 let end: usize = 64 - k;
103 for i in 0..end {
104 a[i] = u64x8::splat(0);
105 }
106
107 Ok(())
108}
109
110pub fn des_reduction(h: &[u64x8; 64], i: u64) -> [u64x8; 64] {
113 let mut sum = bitsliced_add_single(h, i);
114 bitsliced_modulo_power_of_two_inline(&mut sum, 56).unwrap();
115 sum
116}
117
118pub fn des_reduction_inline(h: &mut [u64x8; 64], i: u64) {
119 bitsliced_add_single_inline(h, i);
120 bitsliced_modulo_power_of_two_inline(h, 56).unwrap();
121}
122
123static USE_GFNI: OnceLock<bool> = OnceLock::new();
124
125pub fn transpose_64x64(input: &[u64; 64]) -> [u64; 64] {
128 if *USE_GFNI.get_or_init(|| {
129 #[cfg(target_arch = "x86_64")]
130 {
131 std::is_x86_feature_detected!("gfni")
132 && std::is_x86_feature_detected!("avx512f")
133 && std::is_x86_feature_detected!("avx512bw")
134 && std::is_x86_feature_detected!("avx512vbmi")
135 }
136 #[cfg(not(target_arch = "x86_64"))]
137 {
138 false
139 }
140 }) {
141 unsafe { crate::transpose::transpose_gfni(input) }
142 } else {
143 transpose_scalar(input)
144 }
145}
146
/// Unit tests for the bitsliced arithmetic helpers and the transpose dispatcher.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let mut b = [ZERO; 64];
        b[63] = ALL_ONES;
        let sum = bitsliced_add(&a, &b);
        assert_eq!(sum[63], ZERO);
        assert_eq!(sum[62], ALL_ONES);
        for plane in &sum[..62] {
            assert_eq!(*plane, ZERO);
        }
    }

    #[test]
    fn test_add_single_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let sum = bitsliced_add_single(&a, 1);
        assert_eq!(sum[63], ZERO);
        assert_eq!(sum[62], ALL_ONES);
        for plane in &sum[..62] {
            assert_eq!(*plane, ZERO);
        }
    }

    #[test]
    fn test_add_inline_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let mut b = [ZERO; 64];
        b[63] = ALL_ONES;
        bitsliced_add_inline(&mut a, &b);
        assert_eq!(a[63], ZERO);
        assert_eq!(a[62], ALL_ONES);
        for plane in &a[..62] {
            assert_eq!(*plane, ZERO);
        }
    }

    #[test]
    fn test_add_single_inline_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        bitsliced_add_single_inline(&mut a, 1);
        assert_eq!(a[63], ZERO);
        assert_eq!(a[62], ALL_ONES);
        for plane in &a[..62] {
            assert_eq!(*plane, ZERO);
        }
    }

    #[test]
    fn test_add_single_carries_through_all_planes() {
        // 1 + (2^64 - 1) wraps to 0: the carry must ripple through every plane.
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        assert_eq!(bitsliced_add_single(&a, u64::MAX), [ZERO; 64]);
        bitsliced_add_single_inline(&mut a, u64::MAX);
        assert_eq!(a, [ZERO; 64]);
    }

    #[test]
    fn test_add_wraps_and_inline_matches() {
        // (2^64 - 1) + 2 wraps to 1 in every lane.
        let a = [ALL_ONES; 64];
        let mut b = [ZERO; 64];
        b[62] = ALL_ONES;
        let mut one = [ZERO; 64];
        one[63] = ALL_ONES;

        let sum = bitsliced_add(&a, &b);
        assert_eq!(sum, one);

        let mut in_place = a;
        bitsliced_add_inline(&mut in_place, &b);
        assert_eq!(in_place, sum);
    }

    #[test]
    fn test_modulo_works() {
        let a = [ALL_ONES; 64];
        let res = bitsliced_modulo_power_of_two(&a, 56).unwrap();
        for plane in &res[..8] {
            assert_eq!(*plane, ZERO);
        }
        for plane in &res[8..] {
            assert_eq!(*plane, ALL_ONES);
        }
    }

    #[test]
    fn test_modulo_inline_works() {
        let mut a = [ALL_ONES; 64];
        bitsliced_modulo_power_of_two_inline(&mut a, 56).unwrap();
        for plane in &a[..8] {
            assert_eq!(*plane, ZERO);
        }
        for plane in &a[8..] {
            assert_eq!(*plane, ALL_ONES);
        }
    }

    #[test]
    fn test_modulo_boundaries() {
        let a = [ALL_ONES; 64];
        // mod 2^64 is the identity; mod 2^0 = mod 1 is always zero.
        assert_eq!(bitsliced_modulo_power_of_two(&a, 64).unwrap(), a);
        assert_eq!(bitsliced_modulo_power_of_two(&a, 0).unwrap(), [ZERO; 64]);

        let mut b = a;
        bitsliced_modulo_power_of_two_inline(&mut b, 64).unwrap();
        assert_eq!(b, a);
        bitsliced_modulo_power_of_two_inline(&mut b, 0).unwrap();
        assert_eq!(b, [ZERO; 64]);
    }

    #[test]
    fn test_modulo_rejects_k_above_64() {
        let mut a = [ALL_ONES; 64];
        assert!(bitsliced_modulo_power_of_two(&a, 65).is_err());
        assert!(bitsliced_modulo_power_of_two_inline(&mut a, 65).is_err());
        // The inline variant must leave its input untouched on error.
        assert_eq!(a, [ALL_ONES; 64]);
    }

    #[test]
    fn test_des_reduction_wraps_and_truncates() {
        // (2^64 - 1) + 2 wraps to 1, which already fits in 56 bits.
        let h = [ALL_ONES; 64];
        let mut expected = [ZERO; 64];
        expected[63] = ALL_ONES;
        assert_eq!(des_reduction(&h, 2), expected);

        let mut in_place = h;
        des_reduction_inline(&mut in_place, 2);
        assert_eq!(in_place, expected);
    }

    #[test]
    fn test_des_reduction_drops_high_bits() {
        // 0 + 2^63 only sets a bit above the 56-bit window, so the result is 0.
        let h = [ZERO; 64];
        assert_eq!(des_reduction(&h, 1 << 63), [ZERO; 64]);
    }

    #[test]
    fn test_transpose_dispatch_matches_scalar_and_round_trips() {
        let mut input = [0u64; 64];
        for (i, row) in input.iter_mut().enumerate() {
            *row = (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ (1u64 << i);
        }
        // Whichever kernel is selected at runtime must agree with the scalar one.
        assert_eq!(transpose_64x64(&input), transpose_scalar(&input));
        // Transposing twice is the identity.
        assert_eq!(transpose_64x64(&transpose_64x64(&input)), input);
    }
}
226}