1#[cfg(target_arch = "x86")]
2use core::arch::x86::*;
3#[cfg(target_arch = "x86_64")]
4use core::arch::x86_64::*;
5
6use crate::definition::{final_reduction, M61, MODULUS};
7
/// Folds `len` 256-bit chunks at `ptr` (plus a pre-loaded highest-order
/// chunk in `hi`) into a single residue and hands it to `final_reduction`.
///
/// The input is treated as a little-endian big integer split into 256-bit
/// chunks; `hi` seeds the Horner evaluation as the highest-order chunk and
/// the loop walks from chunk `len - 1` down to chunk `0`.
///
/// The shift constants all pair up against 61, i.e. this assumes
/// `MODULUS == 2^61 - 1` (consistent with the `M61` return type):
/// - per 256-bit chunk: 2^256 ≡ 2^12 (mod 2^61 − 1), hence the 12/49 pair;
/// - 128-bit fold:      2^128 ≡ 2^6,                 hence the  6/55 pair;
/// - 64-bit fold:       2^64  ≡ 2^3,                 hence the  3/58 pair.
///
/// # Safety
///
/// Caller must ensure AVX2 is available and that `ptr` is valid for `len`
/// unaligned 256-bit reads.
#[target_feature(enable = "avx2")]
unsafe fn reduction_core(ptr: *const __m256i, mut len: usize, mut hi: __m256i) -> M61 {
    // Per-lane masks: low 61 bits, and low 49 bits (MODULUS >> 12) so that a
    // subsequent shift-left by 12 cannot overflow a 64-bit lane.
    let mlo = _mm256_set1_epi64x(MODULUS as i64);
    let mhi = _mm256_set1_epi64x((MODULUS >> 12) as i64);

    // Partially reduce the seed chunk: x ≡ (x & (2^61-1)) + (x >> 61).
    hi = _mm256_add_epi64(_mm256_and_si256(hi, mlo), _mm256_srli_epi64::<61>(hi));

    while len > 0 {
        len -= 1;

        // Horner step: hi = hi * 2^12 + chunk[len]  (all per 64-bit lane,
        // with both operands kept partially reduced to avoid overflow).
        let lo = ptr.add(len).read_unaligned();
        let lr = _mm256_add_epi64(_mm256_and_si256(mlo, lo), _mm256_srli_epi64::<61>(lo));
        let hr = _mm256_add_epi64(
            _mm256_slli_epi64::<12>(_mm256_and_si256(hi, mhi)),
            _mm256_srli_epi64::<49>(hi),
        );
        hi = _mm256_add_epi64(lr, hr);
    }

    // Fold the upper 128-bit half onto the lower: upper half carries an
    // extra weight of 2^128 ≡ 2^6 (mod 2^61 - 1).
    let lo = _mm256_castsi256_si128(hi);
    let mut hi = _mm256_extracti128_si256::<1>(hi);

    let mlo = _mm_set1_epi64x(MODULUS as i64);
    let mhi = _mm_set1_epi64x((MODULUS >> 6) as i64);

    let lr = _mm_add_epi64(_mm_and_si128(mlo, lo), _mm_srli_epi64::<61>(lo));
    let hr = _mm_add_epi64(
        _mm_slli_epi64::<6>(_mm_and_si128(hi, mhi)),
        _mm_srli_epi64::<55>(hi),
    );
    hi = _mm_add_epi64(lr, hr);

    // Fold the upper 64-bit lane onto the lower: weight 2^64 ≡ 2^3.
    let lo = _mm_cvtsi128_si64x(hi) as u64;
    let mut hi = _mm_extract_epi64::<1>(hi) as u64;

    hi = (lo & MODULUS) + (lo >> 61) + ((hi & (MODULUS >> 3)) << 3) + (hi >> 58);

    // `hi` is now a small multiple of the residue range; `final_reduction`
    // brings it into canonical [0, 2^61 - 1).
    final_reduction(hi)
}
53
54#[target_feature(enable = "avx2")]
55pub unsafe fn reduce_u8(s: &[u8]) -> M61 {
56 let hi = if s.len() & 31 != 0 {
57 let mut lo = _mm_setzero_si128();
58 let mut hi = _mm_setzero_si128();
59
60 let l = s.len() & !31;
61 let mut ptr = s.as_ptr().add(l);
62
63 if s.len() & 16 != 0 {
64 lo = (ptr as *const __m128i).read_unaligned();
65 ptr = ptr.add(16);
66 }
67
68 let mut tmp = _mm_setzero_si128();
69 for i in (0..(s.len() & 15)).rev() {
70 tmp = _mm_bslli_si128::<1>(tmp);
71 tmp = _mm_insert_epi8::<0>(tmp, *ptr.add(i) as i32);
72 }
73
74 if s.len() & 16 != 0 {
75 hi = tmp;
76 } else {
77 lo = tmp;
78 }
79
80 _mm256_set_m128i(hi, lo)
81 } else {
82 _mm256_setzero_si256()
83 };
84
85 reduction_core(s.as_ptr() as *const __m256i, s.len() >> 5, hi)
86}
87
88#[target_feature(enable = "avx2")]
89pub unsafe fn reduce_u16(s: &[u16]) -> M61 {
90 let hi = if s.len() & 15 != 0 {
91 let mut lo = _mm_setzero_si128();
92 let mut hi = _mm_setzero_si128();
93
94 let l = s.len() & !15;
95 let mut ptr = s.as_ptr().add(l);
96
97 if s.len() & 8 != 0 {
98 lo = (ptr as *const __m128i).read_unaligned();
99 ptr = ptr.add(8);
100 }
101
102 let mut tmp = _mm_setzero_si128();
103 for i in (0..(s.len() & 7)).rev() {
104 tmp = _mm_bslli_si128::<2>(tmp);
105 tmp = _mm_insert_epi16::<0>(tmp, *ptr.add(i) as i32);
106 }
107
108 if s.len() & 8 != 0 {
109 hi = tmp;
110 } else {
111 lo = tmp;
112 }
113
114 _mm256_set_m128i(hi, lo)
115 } else {
116 _mm256_setzero_si256()
117 };
118
119 reduction_core(s.as_ptr() as *const __m256i, s.len() >> 4, hi)
120}
121
122#[target_feature(enable = "avx2")]
123pub unsafe fn reduce_u32(s: &[u32]) -> M61 {
124 let hi = if s.len() & 7 != 0 {
125 let mut arr = [0; 8];
126 let l = s.len() & !7;
127
128 for i in l..s.len() {
129 arr[i - l] = *s.get_unchecked(i);
130 }
131
132 (arr.as_ptr() as *const __m256i).read_unaligned()
133 } else {
134 _mm256_setzero_si256()
135 };
136
137 reduction_core(s.as_ptr() as *const __m256i, s.len() >> 3, hi)
138}
139
140#[target_feature(enable = "avx2")]
141pub unsafe fn reduce_u64(s: &[u64]) -> M61 {
142 let hi = if s.len() & 3 != 0 {
143 let mut arr = [0; 4];
144 let l = s.len() & !3;
145
146 for i in l..s.len() {
147 arr[i - l] = *s.get_unchecked(i);
148 }
149
150 (arr.as_ptr() as *const __m256i).read_unaligned()
151 } else {
152 _mm256_setzero_si256()
153 };
154
155 reduction_core(s.as_ptr() as *const __m256i, s.len() >> 2, hi)
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    // All tests compare the AVX2 path against the scalar reference in
    // `crate::fallback` and silently skip on hardware without AVX2.
    //
    // The `*_max` tests feed all-ones elements at every length in 0..1000,
    // which exercises every tail-length branch (and every `len % chunk`
    // combination) with maximal lane values.

    #[test]
    fn reduce_u8_max() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }

        for len in 0..1000 {
            let vec = vec![u8::MAX; len];

            let expected = crate::fallback::reduce_u8(&vec);
            let actual = unsafe { reduce_u8(&vec) };
            assert_eq!(
                expected, actual,
                "expected: {expected:x}, actual: {actual:x}"
            );
        }
    }

    #[test]
    fn reduce_u16_max() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }

        for len in 0..1000 {
            let vec = vec![u16::MAX; len];

            let expected = crate::fallback::reduce_u16(&vec);
            let actual = unsafe { reduce_u16(&vec) };
            assert_eq!(
                expected, actual,
                "expected: {expected:x}, actual: {actual:x}"
            );
        }
    }

    #[test]
    fn reduce_u32_max() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }

        for len in 0..1000 {
            let vec = vec![u32::MAX; len];

            let expected = crate::fallback::reduce_u32(&vec);
            let actual = unsafe { reduce_u32(&vec) };
            assert_eq!(
                expected, actual,
                "expected: {expected:x}, actual: {actual:x}"
            );
        }
    }

    #[test]
    fn reduce_u64_max() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }

        for len in 0..1000 {
            let vec = vec![u64::MAX; len];

            let expected = crate::fallback::reduce_u64(&vec);
            let actual = unsafe { reduce_u64(&vec) };
            assert_eq!(
                expected, actual,
                "expected: {expected:x}, actual: {actual:x}"
            );
        }
    }

    // Property tests: for arbitrary inputs, the SIMD result must agree with
    // the scalar fallback (returning `true` also skips when AVX2 is absent).
    quickcheck::quickcheck! {
        fn reduce_u8_correct(slice: Vec<u8>) -> bool {
            if !std::arch::is_x86_feature_detected!("avx2") {
                return true;
            }

            let expected = crate::fallback::reduce_u8(&slice);
            let actual = unsafe { reduce_u8(&slice) };
            expected == actual
        }

        fn reduce_u16_correct(slice: Vec<u16>) -> bool {
            if !std::arch::is_x86_feature_detected!("avx2") {
                return true;
            }

            let expected = crate::fallback::reduce_u16(&slice);
            let actual = unsafe { reduce_u16(&slice) };
            expected == actual
        }

        fn reduce_u32_correct(slice: Vec<u32>) -> bool {
            if !std::arch::is_x86_feature_detected!("avx2") {
                return true;
            }

            let expected = crate::fallback::reduce_u32(&slice);
            let actual = unsafe { reduce_u32(&slice) };
            expected == actual
        }

        fn reduce_u64_correct(slice: Vec<u64>) -> bool {
            if !std::arch::is_x86_feature_detected!("avx2") {
                return true;
            }

            let expected = crate::fallback::reduce_u64(&slice);
            let actual = unsafe { reduce_u64(&slice) };
            expected == actual
        }
    }
}