m61_modulus/simd/avx2.rs

1#[cfg(target_arch = "x86")]
2use core::arch::x86::*;
3#[cfg(target_arch = "x86_64")]
4use core::arch::x86_64::*;
5
6use crate::definition::{final_reduction, M61, MODULUS};
7
/// Folds `len` 256-bit chunks (read back-to-front from `ptr`) together with a
/// carry-in accumulator `hi` into a single residue modulo the Mersenne prime
/// MODULUS = 2^61 - 1 (cf. the 61-bit shift/mask pairs below).
///
/// Each 64-bit lane is reduced independently; the four lanes, then the two
/// 128-bit halves, then the two 64-bit scalars are merged at the end using the
/// appropriate power-of-two weight for their bit offset.
///
/// # Safety
/// The caller must ensure AVX2 is available and that `ptr..ptr + len` is
/// valid for `len` unaligned 256-bit reads.
#[target_feature(enable = "avx2")]
unsafe fn reduction_core(ptr: *const __m256i, mut len: usize, mut hi: __m256i) -> M61 {
    let mlo = _mm256_set1_epi64x(MODULUS as i64);
    // Mask of the low 49 bits: used to split a lane before multiplying it by
    // 2^12 (12 + 49 = 61, so both pieces land back below ~2^61).
    let mhi = _mm256_set1_epi64x((MODULUS >> 12) as i64);

    // Initial reduction of high elements.
    hi = _mm256_add_epi64(_mm256_and_si256(hi, mlo), _mm256_srli_epi64::<61>(hi));

    while len > 0 {
        len -= 1;

        let lo = ptr.add(len).read_unaligned();
        // Partially reduce the new chunk: x ≡ (x & M) + (x >> 61) (mod M).
        let lr = _mm256_add_epi64(_mm256_and_si256(mlo, lo), _mm256_srli_epi64::<61>(lo));
        // The running accumulator sits one whole 256-bit chunk above the new
        // one; 256 = 4·61 + 12, so 2^256 ≡ 2^12 (mod 2^61 - 1). Multiply by
        // 2^12 via the split shifts (bits above 49 wrap around).
        let hr = _mm256_add_epi64(
            _mm256_slli_epi64::<12>(_mm256_and_si256(hi, mhi)),
            _mm256_srli_epi64::<49>(hi),
        );
        hi = _mm256_add_epi64(lr, hr);
    }

    // One reduction step using 128-bit operands
    // halves the problem size.

    let lo = _mm256_castsi256_si128(hi);
    let mut hi = _mm256_extracti128_si256::<1>(hi);

    let mlo = _mm_set1_epi64x(MODULUS as i64);
    let mhi = _mm_set1_epi64x((MODULUS >> 6) as i64);

    // The upper half sits 128 bits above the lower; 128 = 2·61 + 6, so the
    // upper half carries weight 2^6 (split shifts 6 + 55 = 61).
    let lr = _mm_add_epi64(_mm_and_si128(mlo, lo), _mm_srli_epi64::<61>(lo));
    let hr = _mm_add_epi64(
        _mm_slli_epi64::<6>(_mm_and_si128(hi, mhi)),
        _mm_srli_epi64::<55>(hi),
    );
    hi = _mm_add_epi64(lr, hr);

    // Last reduction step done using scalar operations.

    let lo = _mm_cvtsi128_si64x(hi) as u64;
    let mut hi = _mm_extract_epi64::<1>(hi) as u64;

    // The high 64-bit word carries weight 2^64 ≡ 2^3 (64 = 61 + 3).
    hi = (lo & MODULUS) + (lo >> 61) + ((hi & (MODULUS >> 3)) << 3) + (hi >> 58);

    final_reduction(hi)
}
53
54#[target_feature(enable = "avx2")]
55pub unsafe fn reduce_u8(s: &[u8]) -> M61 {
56    let hi = if s.len() & 31 != 0 {
57        let mut lo = _mm_setzero_si128();
58        let mut hi = _mm_setzero_si128();
59
60        let l = s.len() & !31;
61        let mut ptr = s.as_ptr().add(l);
62
63        if s.len() & 16 != 0 {
64            lo = (ptr as *const __m128i).read_unaligned();
65            ptr = ptr.add(16);
66        }
67
68        let mut tmp = _mm_setzero_si128();
69        for i in (0..(s.len() & 15)).rev() {
70            tmp = _mm_bslli_si128::<1>(tmp);
71            tmp = _mm_insert_epi8::<0>(tmp, *ptr.add(i) as i32);
72        }
73
74        if s.len() & 16 != 0 {
75            hi = tmp;
76        } else {
77            lo = tmp;
78        }
79
80        _mm256_set_m128i(hi, lo)
81    } else {
82        _mm256_setzero_si256()
83    };
84
85    reduction_core(s.as_ptr() as *const __m256i, s.len() >> 5, hi)
86}
87
88#[target_feature(enable = "avx2")]
89pub unsafe fn reduce_u16(s: &[u16]) -> M61 {
90    let hi = if s.len() & 15 != 0 {
91        let mut lo = _mm_setzero_si128();
92        let mut hi = _mm_setzero_si128();
93
94        let l = s.len() & !15;
95        let mut ptr = s.as_ptr().add(l);
96
97        if s.len() & 8 != 0 {
98            lo = (ptr as *const __m128i).read_unaligned();
99            ptr = ptr.add(8);
100        }
101
102        let mut tmp = _mm_setzero_si128();
103        for i in (0..(s.len() & 7)).rev() {
104            tmp = _mm_bslli_si128::<2>(tmp);
105            tmp = _mm_insert_epi16::<0>(tmp, *ptr.add(i) as i32);
106        }
107
108        if s.len() & 8 != 0 {
109            hi = tmp;
110        } else {
111            lo = tmp;
112        }
113
114        _mm256_set_m128i(hi, lo)
115    } else {
116        _mm256_setzero_si256()
117    };
118
119    reduction_core(s.as_ptr() as *const __m256i, s.len() >> 4, hi)
120}
121
122#[target_feature(enable = "avx2")]
123pub unsafe fn reduce_u32(s: &[u32]) -> M61 {
124    let hi = if s.len() & 7 != 0 {
125        let mut arr = [0; 8];
126        let l = s.len() & !7;
127
128        for i in l..s.len() {
129            arr[i - l] = *s.get_unchecked(i);
130        }
131
132        (arr.as_ptr() as *const __m256i).read_unaligned()
133    } else {
134        _mm256_setzero_si256()
135    };
136
137    reduction_core(s.as_ptr() as *const __m256i, s.len() >> 3, hi)
138}
139
140#[target_feature(enable = "avx2")]
141pub unsafe fn reduce_u64(s: &[u64]) -> M61 {
142    let hi = if s.len() & 3 != 0 {
143        let mut arr = [0; 4];
144        let l = s.len() & !3;
145
146        for i in l..s.len() {
147            arr[i - l] = *s.get_unchecked(i);
148        }
149
150        (arr.as_ptr() as *const __m256i).read_unaligned()
151    } else {
152        _mm256_setzero_si256()
153    };
154
155    reduction_core(s.as_ptr() as *const __m256i, s.len() >> 2, hi)
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    // Boundary test: all-0xFF inputs at every length in 0..1000 exercise
    // every tail-size combination and maximal carry values; results are
    // cross-checked against the scalar fallback implementation.
    #[test]
    fn reduce_u8_max() {
        // Skip on hosts without AVX2 — calling the SIMD path there would be
        // undefined behavior.
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }

        for len in 0..1000 {
            let vec = vec![u8::MAX; len];

            let expected = crate::fallback::reduce_u8(&vec);
            let actual = unsafe { reduce_u8(&vec) };
            assert_eq!(
                expected, actual,
                "expected: {expected:x}, actual: {actual:x}"
            );
        }
    }

    // Same boundary sweep as `reduce_u8_max`, for 16-bit elements.
    #[test]
    fn reduce_u16_max() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }

        for len in 0..1000 {
            let vec = vec![u16::MAX; len];

            let expected = crate::fallback::reduce_u16(&vec);
            let actual = unsafe { reduce_u16(&vec) };
            assert_eq!(
                expected, actual,
                "expected: {expected:x}, actual: {actual:x}"
            );
        }
    }

    // Same boundary sweep as `reduce_u8_max`, for 32-bit elements.
    #[test]
    fn reduce_u32_max() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }

        for len in 0..1000 {
            let vec = vec![u32::MAX; len];

            let expected = crate::fallback::reduce_u32(&vec);
            let actual = unsafe { reduce_u32(&vec) };
            assert_eq!(
                expected, actual,
                "expected: {expected:x}, actual: {actual:x}"
            );
        }
    }

    // Same boundary sweep as `reduce_u8_max`, for 64-bit elements.
    #[test]
    fn reduce_u64_max() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }

        for len in 0..1000 {
            let vec = vec![u64::MAX; len];

            let expected = crate::fallback::reduce_u64(&vec);
            let actual = unsafe { reduce_u64(&vec) };
            assert_eq!(
                expected, actual,
                "expected: {expected:x}, actual: {actual:x}"
            );
        }
    }

    // Property tests: on random inputs of random lengths, the AVX2 path must
    // agree with the scalar fallback. Each test trivially passes when the
    // host lacks AVX2.
    quickcheck::quickcheck! {
        fn reduce_u8_correct(slice: Vec<u8>) -> bool {
            if !std::arch::is_x86_feature_detected!("avx2") {
                return true;
            }

            let expected = crate::fallback::reduce_u8(&slice);
            let actual = unsafe { reduce_u8(&slice) };
            expected == actual
        }

        fn reduce_u16_correct(slice: Vec<u16>) -> bool {
            if !std::arch::is_x86_feature_detected!("avx2") {
                return true;
            }

            let expected = crate::fallback::reduce_u16(&slice);
            let actual = unsafe { reduce_u16(&slice) };
            expected == actual
        }

        fn reduce_u32_correct(slice: Vec<u32>) -> bool {
            if !std::arch::is_x86_feature_detected!("avx2") {
                return true;
            }

            let expected = crate::fallback::reduce_u32(&slice);
            let actual = unsafe { reduce_u32(&slice) };
            expected == actual
        }

        fn reduce_u64_correct(slice: Vec<u64>) -> bool {
            if !std::arch::is_x86_feature_detected!("avx2") {
                return true;
            }

            let expected = crate::fallback::reduce_u64(&slice);
            let actual = unsafe { reduce_u64(&slice) };
            expected == actual
        }
    }
}