// block_aligner/avx2.rs

#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

pub type Simd = __m256i; // used for storing DP scores
pub type HalfSimd = __m128i; // used for storing bytes (sequence or scoring matrix)
pub type LutSimd = __m128i; // used for storing a row in a scoring matrix (always 128 bits)
pub type TraceType = i32;
/// Number of 16-bit lanes in a SIMD vector.
pub const L: usize = 16;
pub const L_BYTES: usize = L * 2;
pub const HALFSIMD_MUL: usize = 1;
// using MIN = 0 is faster, but it restricts the range of scores (and the max block size)
pub const ZERO: i16 = 1 << 14;
pub const MIN: i16 = 0;
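// Informal note (not in the original comments): scores appear to be stored with a bias
// of ZERO (1 << 14), so saturating i16 arithmetic clamps at MIN instead of wrapping;
// this is an assumption based on how ZERO and MIN are used below (e.g. in simd_prefix_hadd_i16).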

// Non-temporal store to avoid cluttering cache with traces
// Actually, non-temporal stores are slower in benchmarks!
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn store_trace(ptr: *mut TraceType, trace: TraceType) { *ptr = trace; } // _mm_stream_si32(ptr, trace);

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_adds_i16(a: Simd, b: Simd) -> Simd { _mm256_adds_epi16(a, b) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_subs_i16(a: Simd, b: Simd) -> Simd { _mm256_subs_epi16(a, b) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_max_i16(a: Simd, b: Simd) -> Simd { _mm256_max_epi16(a, b) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_cmpeq_i16(a: Simd, b: Simd) -> Simd { _mm256_cmpeq_epi16(a, b) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_cmpgt_i16(a: Simd, b: Simd) -> Simd { _mm256_cmpgt_epi16(a, b) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_blend_i8(a: Simd, b: Simd, mask: Simd) -> Simd { _mm256_blendv_epi8(a, b, mask) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_load(ptr: *const Simd) -> Simd { _mm256_load_si256(ptr) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_loadu(ptr: *const Simd) -> Simd { _mm256_loadu_si256(ptr) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_store(ptr: *mut Simd, a: Simd) { _mm256_store_si256(ptr, a) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_set1_i16(v: i16) -> Simd { _mm256_set1_epi16(v) }

#[macro_export]
#[doc(hidden)]
macro_rules! simd_extract_i16 {
    ($a:expr, $num:expr) => {
        {
            debug_assert!($num < L);
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;
            _mm256_extract_epi16($a, $num as i32) as i16
        }
    };
}

#[macro_export]
#[doc(hidden)]
macro_rules! simd_insert_i16 {
    ($a:expr, $v:expr, $num:expr) => {
        {
            debug_assert!($num < L);
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;
            _mm256_insert_epi16($a, $v, $num as i32)
        }
    };
}

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_movemask_i8(a: Simd) -> u32 { _mm256_movemask_epi8(a) as u32 }

#[macro_export]
#[doc(hidden)]
macro_rules! simd_sl_i16 {
    ($a:expr, $b:expr, $num:expr) => {
        {
            debug_assert!(2 * $num <= L);
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;
            if $num == L / 2 {
                _mm256_permute2x128_si256($a, $b, 0x03)
            } else {
                _mm256_alignr_epi8($a, _mm256_permute2x128_si256($a, $b, 0x03), (L - (2 * $num)) as i32)
            }
        }
    };
}

#[macro_export]
#[doc(hidden)]
macro_rules! simd_sr_i16 {
    ($a:expr, $b:expr, $num:expr) => {
        {
            debug_assert!(2 * $num <= L);
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;
            if $num == L / 2 {
                _mm256_permute2x128_si256($a, $b, 0x03)
            } else {
                _mm256_alignr_epi8(_mm256_permute2x128_si256($a, $b, 0x03), $b, (2 * $num) as i32)
            }
        }
    };
}
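
// Informal note (not in the original comments): simd_sl_i16!(a, b, n) shifts `a` up by
// `n` 16-bit lanes, filling the low lanes from the top of `b`; simd_sr_i16!(a, b, n)
// shifts `b` down by `n` lanes, filling the high lanes from the bottom of `a`. See the
// shift tests at the bottom of this file for concrete examples.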

// hardcoded to STEP = 8
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_step(a: Simd, b: Simd) -> Simd {
    _mm256_permute2x128_si256(a, b, 0x03)
}

#[target_feature(enable = "avx2")]
#[inline]
unsafe fn simd_sl_i128(a: Simd, b: Simd) -> Simd {
    _mm256_permute2x128_si256(a, b, 0x03)
}

// shift each 128-bit lane left by $num 16-bit elements, shifting in zeros
macro_rules! simd_sllz_i16 {
    ($a:expr, $num:expr) => {
        {
            debug_assert!(2 * $num < L);
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;
            _mm256_slli_si256($a, ($num * 2) as i32)
        }
    };
}

// broadcast the last 16-bit element to the whole vector
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_broadcasthi_i16(v: Simd) -> Simd {
    let v = _mm256_shufflehi_epi16(v, 0b11111111);
    _mm256_permute4x64_epi64(v, 0b11111111)
}
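
// Informal example (not in the original comments): simd_broadcasthi_i16 of
// [0, 1, ..., 15] yields [15; 16].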

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_slow_extract_i16(v: Simd, i: usize) -> i16 {
    debug_assert!(i < L);

    #[repr(align(32))]
    struct A([i16; L]);

    let mut a = A([0i16; L]);
    simd_store(a.0.as_mut_ptr() as *mut Simd, v);
    *a.0.as_ptr().add(i)
}

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_hmax_i16(v: Simd) -> i16 {
    let mut v2 = _mm256_max_epi16(v, _mm256_srli_si256(v, 2));
    v2 = _mm256_max_epi16(v2, _mm256_srli_si256(v2, 4));
    v2 = _mm256_max_epi16(v2, _mm256_srli_si256(v2, 8));
    v2 = _mm256_max_epi16(v2, simd_sl_i128(v2, v2));
    simd_extract_i16!(v2, 0)
}

#[macro_export]
#[doc(hidden)]
macro_rules! simd_prefix_hadd_i16 {
    ($a:expr, $num:expr) => {
        {
            debug_assert!(2 * $num <= L);
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;
            let mut v = _mm256_subs_epi16($a, _mm256_set1_epi16(ZERO));
            if $num > 4 {
                v = _mm256_adds_epi16(v, _mm256_srli_si256(v, 8));
            }
            if $num > 2 {
                v = _mm256_adds_epi16(v, _mm256_srli_si256(v, 4));
            }
            if $num > 1 {
                v = _mm256_adds_epi16(v, _mm256_srli_si256(v, 2));
            }
            simd_extract_i16!(v, 0)
        }
    };
}

#[macro_export]
#[doc(hidden)]
macro_rules! simd_prefix_hmax_i16 {
    ($a:expr, $num:expr) => {
        {
            debug_assert!(2 * $num <= L);
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;
            let mut v = $a;
            if $num > 4 {
                v = _mm256_max_epi16(v, _mm256_srli_si256(v, 8));
            }
            if $num > 2 {
                v = _mm256_max_epi16(v, _mm256_srli_si256(v, 4));
            }
            if $num > 1 {
                v = _mm256_max_epi16(v, _mm256_srli_si256(v, 2));
            }
            simd_extract_i16!(v, 0)
        }
    };
}

#[macro_export]
#[doc(hidden)]
macro_rules! simd_suffix_hmax_i16 {
    ($a:expr, $num:expr) => {
        {
            debug_assert!(2 * $num <= L);
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;
            let mut v = $a;
            if $num > 4 {
                v = _mm256_max_epi16(v, _mm256_slli_si256(v, 8));
            }
            if $num > 2 {
                v = _mm256_max_epi16(v, _mm256_slli_si256(v, 4));
            }
            if $num > 1 {
                v = _mm256_max_epi16(v, _mm256_slli_si256(v, 2));
            }
            simd_extract_i16!(v, 15)
        }
    };
}

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn simd_hargmax_i16(v: Simd, max: i16) -> usize {
    let v2 = _mm256_cmpeq_epi16(v, _mm256_set1_epi16(max));
    (simd_movemask_i8(v2).trailing_zeros() as usize) / 2
}
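
// Informal note (not in the original comments): simd_hargmax_i16 above returns the
// index of the lowest lane equal to `max` (ties are broken toward lane 0).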

#[target_feature(enable = "avx2")]
#[inline]
#[allow(non_snake_case)]
#[allow(dead_code)]
pub unsafe fn simd_naive_prefix_scan_i16(R_max: Simd, gap_cost: Simd, _gap_cost_lane: PrefixScanConsts) -> Simd {
    let mut curr = R_max;

    for _i in 0..(L - 1) {
        let prev = curr;
        curr = simd_sl_i16!(curr, _mm256_setzero_si256(), 1);
        curr = _mm256_adds_epi16(curr, gap_cost);
        curr = _mm256_max_epi16(curr, prev);
    }

    curr
}

pub type PrefixScanConsts = Simd;

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn get_prefix_scan_consts(gap: Simd) -> (Simd, PrefixScanConsts) {
    let mut shift1 = simd_sllz_i16!(gap, 1);
    shift1 = _mm256_adds_epi16(shift1, gap);
    let mut shift2 = simd_sllz_i16!(shift1, 2);
    shift2 = _mm256_adds_epi16(shift2, shift1);
    let mut shift4 = simd_sllz_i16!(shift2, 4);
    shift4 = _mm256_adds_epi16(shift4, shift2);

    let mut correct1 = _mm256_srli_si256(_mm256_shufflehi_epi16(shift4, 0b11111111), 8);
    correct1 = _mm256_permute4x64_epi64(correct1, 0b00000101);
    correct1 = _mm256_adds_epi16(correct1, shift4);

    (correct1, shift4)
}
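
// Informal note (not in the original comments): with a broadcast gap cost g, the first
// value returned above appears to hold the running gap costs [g, 2g, ..., 16g] and the
// PrefixScanConsts value the per-128-bit-lane version [g..8g, g..8g]; this is a reading
// of the shuffles above, not documented behavior.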

#[target_feature(enable = "avx2")]
#[inline]
#[allow(non_snake_case)]
pub unsafe fn simd_prefix_scan_i16(R_max: Simd, gap_cost: Simd, gap_cost_lane: PrefixScanConsts) -> Simd {
    // Optimized prefix add and max within each group of eight elements (one 128-bit lane).
    // Note: be very careful to avoid lane-crossing, which has a large penalty.
    // Also, use as few registers as possible to avoid
    // memory loads (latencies really matter since this is the critical path).
    // Keep the CPU busy with instructions!
    // Note: relies on the min score being 0 for speed!
    let mut shift1 = simd_sllz_i16!(R_max, 1);
    shift1 = _mm256_adds_epi16(shift1, gap_cost);
    shift1 = _mm256_max_epi16(R_max, shift1);
    let mut shift2 = simd_sllz_i16!(shift1, 2);
    shift2 = _mm256_adds_epi16(shift2, _mm256_slli_epi16(gap_cost, 1));
    shift2 = _mm256_max_epi16(shift1, shift2);
    let mut shift4 = simd_sllz_i16!(shift2, 4);
    shift4 = _mm256_adds_epi16(shift4, _mm256_slli_epi16(gap_cost, 2));
    shift4 = _mm256_max_epi16(shift2, shift4);

    // Correct the upper lane using the last element of the lower lane.
    // Make sure that the operation on the bottom lane is essentially a no-op.
    let mut correct1 = _mm256_shufflehi_epi16(shift4, 0b11111111);
    correct1 = _mm256_permute4x64_epi64(correct1, 0b01010000);
    correct1 = _mm256_adds_epi16(correct1, gap_cost_lane);
    _mm256_max_epi16(shift4, correct1)
}
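
// Informal note (not in the original comments): this is intended to match
// simd_naive_prefix_scan_i16 above, i.e. roughly a running max where earlier lanes
// decay by the gap cost per lane; test_prefix_scan at the bottom of this file checks
// a couple of concrete cases.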

// lookup two 128-bit tables
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn halfsimd_lookup2_i16(lut1: LutSimd, lut2: LutSimd, v: HalfSimd) -> Simd {
    let a = _mm_shuffle_epi8(lut1, v);
    let b = _mm_shuffle_epi8(lut2, v);
    // only the most significant bit of each byte matters for blendv
    let mask = _mm_slli_epi16(v, 3);
    let c = _mm_blendv_epi8(a, b, mask);
    _mm256_cvtepi8_epi16(c)
}
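
// Informal note (not in the original comments): halfsimd_lookup2_i16 above treats each
// byte of `v` as a 5-bit index: the low 4 bits index within a 16-byte LUT, and bit 4
// (moved into each byte's sign bit by the shift left by 3) selects lut1 vs lut2 in the blend.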

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn halfsimd_lookup1_i16(lut: LutSimd, v: HalfSimd) -> Simd {
    _mm256_cvtepi8_epi16(_mm_shuffle_epi8(lut, v))
}

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn halfsimd_lookup_bytes_i16(match_scores: HalfSimd, mismatch_scores: HalfSimd, a: HalfSimd, b: HalfSimd) -> Simd {
    let mask = _mm_cmpeq_epi8(a, b);
    let c = _mm_blendv_epi8(mismatch_scores, match_scores, mask);
    _mm256_cvtepi8_epi16(c)
}

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn halfsimd_load(ptr: *const HalfSimd) -> HalfSimd { _mm_load_si128(ptr) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn halfsimd_loadu(ptr: *const HalfSimd) -> HalfSimd { _mm_loadu_si128(ptr) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn lutsimd_load(ptr: *const LutSimd) -> LutSimd { _mm_load_si128(ptr) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn lutsimd_loadu(ptr: *const LutSimd) -> LutSimd { _mm_loadu_si128(ptr) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn halfsimd_store(ptr: *mut HalfSimd, a: HalfSimd) { _mm_store_si128(ptr, a) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn halfsimd_sub_i8(a: HalfSimd, b: HalfSimd) -> HalfSimd { _mm_sub_epi8(a, b) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn halfsimd_set1_i8(v: i8) -> HalfSimd { _mm_set1_epi8(v) }

#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn halfsimd_get_idx(i: usize) -> usize { i }

#[macro_export]
#[doc(hidden)]
macro_rules! halfsimd_sr_i8 {
    ($a:expr, $b:expr, $num:expr) => {
        {
            debug_assert!($num <= L);
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;
            _mm_alignr_epi8($a, $b, $num as i32)
        }
    };
}

#[target_feature(enable = "avx2")]
#[allow(dead_code)]
pub unsafe fn simd_dbg_i16(v: Simd) {
    #[repr(align(32))]
    struct A([i16; L]);

    let mut a = A([0i16; L]);
    simd_store(a.0.as_mut_ptr() as *mut Simd, v);

    for i in (0..a.0.len()).rev() {
        print!("{:6} ", a.0[i]);
    }
    println!();
}

#[target_feature(enable = "avx2")]
#[allow(dead_code)]
pub unsafe fn halfsimd_dbg_i8(v: HalfSimd) {
    #[repr(align(16))]
    struct A([i8; L]);

    let mut a = A([0i8; L]);
    halfsimd_store(a.0.as_mut_ptr() as *mut HalfSimd, v);

    for i in (0..a.0.len()).rev() {
        print!("{:3} ", a.0[i]);
    }
    println!();
}

#[target_feature(enable = "avx2")]
#[allow(dead_code)]
pub unsafe fn simd_assert_vec_eq(a: Simd, b: [i16; L]) {
    #[repr(align(32))]
    struct A([i16; L]);

    let mut arr = A([0i16; L]);
    simd_store(arr.0.as_mut_ptr() as *mut Simd, a);
    assert_eq!(arr.0, b);
}

#[target_feature(enable = "avx2")]
#[allow(dead_code)]
pub unsafe fn halfsimd_assert_vec_eq(a: HalfSimd, b: [i8; L]) {
    #[repr(align(32))]
    struct A([i8; L]);

    let mut arr = A([0i8; L]);
    halfsimd_store(arr.0.as_mut_ptr() as *mut HalfSimd, a);
    assert_eq!(arr.0, b);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prefix_scan() {
        #[target_feature(enable = "avx2")]
        unsafe fn inner() {
            #[repr(align(32))]
            struct A([i16; L]);

            let vec = A([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 12, 13, 14, 11]);
            let gap = simd_set1_i16(0);
            let (_, consts) = get_prefix_scan_consts(gap);
            let res = simd_prefix_scan_i16(simd_load(vec.0.as_ptr() as *const Simd), gap, consts);
            simd_assert_vec_eq(res, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 15, 15, 15, 15]);

            let vec = A([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 12, 13, 14, 11]);
            let gap = simd_set1_i16(-1);
            let (_, consts) = get_prefix_scan_consts(gap);
            let res = simd_prefix_scan_i16(simd_load(vec.0.as_ptr() as *const Simd), gap, consts);
            simd_assert_vec_eq(res, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 14, 13, 14, 13]);
        }
        unsafe { inner(); }
    }
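
    // The tests below are illustrative sanity checks added as hedged sketches; they are
    // not part of the original test suite. Expected values assume simd_sl_i16!(a, b, n)
    // shifts `a` up by `n` lanes (filling from the top of `b`) and simd_sr_i16!(a, b, n)
    // shifts `b` down by `n` lanes (filling from the bottom of `a`).
    #[test]
    fn test_shift() {
        #[target_feature(enable = "avx2")]
        unsafe fn inner() {
            #[repr(align(32))]
            struct A([i16; L]);

            let a = A([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
            let b = A([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]);
            let a = simd_load(a.0.as_ptr() as *const Simd);
            let b = simd_load(b.0.as_ptr() as *const Simd);

            let sl = simd_sl_i16!(a, b, 1);
            simd_assert_vec_eq(sl, [31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]);

            let sr = simd_sr_i16!(a, b, 1);
            simd_assert_vec_eq(sr, [17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0]);
        }
        unsafe { inner(); }
    }

    // Hedged sketch: horizontal max and argmax over 16-bit lanes.
    #[test]
    fn test_hmax() {
        #[target_feature(enable = "avx2")]
        unsafe fn inner() {
            #[repr(align(32))]
            struct A([i16; L]);

            let vec = A([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3]);
            let v = simd_load(vec.0.as_ptr() as *const Simd);
            let max = simd_hmax_i16(v);
            assert_eq!(max, 9);
            // the argmax is the lowest lane that equals the max
            assert_eq!(simd_hargmax_i16(v, max), 5);
        }
        unsafe { inner(); }
    }

    // Hedged sketch: byte-equality lookup that blends match/mismatch scores and widens to i16.
    #[test]
    fn test_lookup_bytes() {
        #[target_feature(enable = "avx2")]
        unsafe fn inner() {
            #[repr(align(16))]
            struct B([i8; L]);

            let a = B([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]);
            let b = B([0, 0, 2, 2, 0, 0, 2, 2, 3, 3, 3, 3, 1, 1, 1, 1]);
            let a = halfsimd_load(a.0.as_ptr() as *const HalfSimd);
            let b = halfsimd_load(b.0.as_ptr() as *const HalfSimd);
            let res = halfsimd_lookup_bytes_i16(halfsimd_set1_i8(2), halfsimd_set1_i8(-1), a, b);
            simd_assert_vec_eq(res, [2, -1, 2, -1, 2, -1, 2, -1, -1, -1, -1, 2, -1, 2, -1, -1]);
        }
        unsafe { inner(); }
    }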
}