1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
use core::mem::{transmute, MaybeUninit};
use core::ops::{Add, Sub, Mul, Div};
use std::arch::x86_64::*;

// Opmask types: bit `j` of a mask selects lane `j` of a vector.
// NOTE(review): Intel's C headers and `std::arch` declare these as unsigned
// (`u8`/`u16`/`u32`/`u64`); these aliases are signed. The bit tests in this file
// (`k & (1 << j) != 0`) work for either signedness — confirm before mixing these
// aliases with `std::arch` AVX-512 intrinsics.

/// 8-bit lane mask.
// Name mirrors Intel's C type, hence the lint allowance.
#[allow(non_camel_case_types)]
pub type __mmask8 = i8;
/// 16-bit lane mask.
#[allow(non_camel_case_types)]
pub type __mmask16 = i16;
/// 32-bit lane mask.
#[allow(non_camel_case_types)]
pub type __mmask32 = i32;
/// 64-bit lane mask.
#[allow(non_camel_case_types)]
pub type __mmask64 = i64;

/// 512-bit integer vector, stored as eight `i64` lanes.
///
/// `#[repr(C)]` pins the field order/layout; the masked-arithmetic helpers in this
/// file `transmute` vectors to and from `[elem; N]` arrays and rely on it.
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct __m512i(i64, i64, i64, i64, i64, i64, i64, i64);

/// 512-bit vector of eight `f64` lanes (`#[repr(C)]` for the same reason as `__m512i`).
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct __m512d(f64, f64, f64, f64, f64, f64, f64, f64);

/// 512-bit vector of sixteen `f32` lanes (`#[repr(C)]` for the same reason as `__m512i`).
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct __m512(
    f32, f32, f32, f32, f32, f32, f32, f32,
    f32, f32, f32, f32, f32, f32, f32, f32,
);

/// Generates masked (`_mask_`) and zero-masked (`_maskz_`) lane-wise absolute-value
/// functions.
///
/// Each row is `mask_fn, maskz_fn, VecType, MaskType, [Elem; LANES];`.
/// - `VecType` is reinterpreted as `[Elem; LANES]` via `transmute`, so both must have the
///   same size (`transmute` rejects mismatched sizes at compile time).
/// - `MaskType` must have at least `LANES` bits: bit `j` of `k` selects lane `j`, and
///   `1 << j` with `j >= bits` panics in debug builds.
///
/// `wrapping_abs` is used so `Elem::MIN` maps to itself (bit pattern `0x80..`) instead
/// of panicking on overflow in debug builds.
macro_rules! impl_mask_arith_abs {
    ($($fn_name_mask: ident, $fn_name_maskz: ident,
    $vec_type: ty, $mask_type: ty, [$elem: ty; $iter_cnt: expr];)*) => {
        $(
/// Lane `j` is `|a[j]|` (wrapping) when bit `j` of `k` is set, otherwise `src[j]`.
pub unsafe fn $fn_name_mask(src: $vec_type, k: $mask_type, a: $vec_type) -> $vec_type {
    let a: [$elem; $iter_cnt] = transmute(a);
    // Unselected lanes keep `src`, so start from it and overwrite only selected lanes.
    let mut dst: [$elem; $iter_cnt] = transmute(src);
    for (j, lane) in dst.iter_mut().enumerate() {
        if k & (0b1 << j) != 0 {
            *lane = a[j].wrapping_abs();
        }
    }
    transmute(dst)
}

/// Lane `j` is `|a[j]|` (wrapping) when bit `j` of `k` is set, otherwise zero.
pub unsafe fn $fn_name_maskz(k: $mask_type, a: $vec_type) -> $vec_type {
    let a: [$elem; $iter_cnt] = transmute(a);
    // Unselected lanes are zero, so start from an all-zero array.
    let mut dst: [$elem; $iter_cnt] = [0; $iter_cnt];
    for (j, lane) in dst.iter_mut().enumerate() {
        if k & (0b1 << j) != 0 {
            *lane = a[j].wrapping_abs();
        }
    }
    transmute(dst)
}
        )*
    };
}

// The mask type must have at least as many bits as there are lanes. For example,
// 16 x i8 needs a 16-bit mask, matching Intel's `_mm_mask_abs_epi8(src, __mmask16 k, a)`.
impl_mask_arith_abs! {
    _mm_mask_abs_epi8,  _mm_maskz_abs_epi8,  __m128i, __mmask16, [i8; 16];
    _mm_mask_abs_epi16, _mm_maskz_abs_epi16, __m128i, __mmask8, [i16; 8];
    _mm_mask_abs_epi32, _mm_maskz_abs_epi32, __m128i, __mmask8, [i32; 4];
    _mm_mask_abs_epi64, _mm_maskz_abs_epi64, __m128i, __mmask8, [i64; 2];
    _mm256_mask_abs_epi8,  _mm256_maskz_abs_epi8,  __m256i, __mmask32, [i8; 32];
    _mm256_mask_abs_epi16, _mm256_maskz_abs_epi16, __m256i, __mmask16, [i16; 16];
    _mm256_mask_abs_epi32, _mm256_maskz_abs_epi32, __m256i, __mmask8, [i32; 8];
    _mm256_mask_abs_epi64, _mm256_maskz_abs_epi64, __m256i, __mmask8, [i64; 4];
    _mm512_mask_abs_epi8,  _mm512_maskz_abs_epi8,  __m512i, __mmask64, [i8; 64];
    _mm512_mask_abs_epi16, _mm512_maskz_abs_epi16, __m512i, __mmask32, [i16; 32];
    _mm512_mask_abs_epi32, _mm512_maskz_abs_epi32, __m512i, __mmask16, [i32; 16];
    _mm512_mask_abs_epi64, _mm512_maskz_abs_epi64, __m512i, __mmask8, [i64; 8];
}

/// Generates masked (`_mask_`) and zero-masked (`_maskz_`) lane-wise binary functions.
///
/// Each row is `mask_fn, maskz_fn, VecType, MaskType, func, zero, [Elem; LANES];`.
/// - `func` is called as `<Elem>::func(a[j], b[j])`, so it may be an inherent method
///   (e.g. `wrapping_add`) or a method of a trait in scope (e.g. `core::ops::Add::add`).
/// - `zero` is the value written to unselected lanes by the `_maskz_` variant.
/// - `VecType` is reinterpreted as `[Elem; LANES]` via `transmute` (sizes must match).
/// - `MaskType` must have at least `LANES` bits: bit `j` of `k` selects lane `j`, and
///   `1 << j` with `j >= bits` panics in debug builds.
macro_rules! impl_mask_arith_binary_vector {
    ($($fn_name_mask: ident, $fn_name_maskz: ident,
    $vec_type: ty, $mask_type: ty, $binary_func: ident, $zero: expr,
    [$elem: ty; $iter_cnt: expr];)*) => {
        $(
/// Lane `j` is `func(a[j], b[j])` when bit `j` of `k` is set, otherwise `src[j]`.
pub unsafe fn $fn_name_mask(src: $vec_type, k: $mask_type, a: $vec_type, b: $vec_type) -> $vec_type {
    let a: [$elem; $iter_cnt] = transmute(a);
    let b: [$elem; $iter_cnt] = transmute(b);
    // Unselected lanes keep `src`, so start from it and overwrite only selected lanes.
    let mut dst: [$elem; $iter_cnt] = transmute(src);
    for (j, lane) in dst.iter_mut().enumerate() {
        if k & (0b1 << j) != 0 {
            *lane = <$elem>::$binary_func(a[j], b[j]);
        }
    }
    transmute(dst)
}

/// Lane `j` is `func(a[j], b[j])` when bit `j` of `k` is set, otherwise `zero`.
pub unsafe fn $fn_name_maskz(k: $mask_type, a: $vec_type, b: $vec_type) -> $vec_type {
    let a: [$elem; $iter_cnt] = transmute(a);
    let b: [$elem; $iter_cnt] = transmute(b);
    // Unselected lanes are `zero`, so start from an array filled with it.
    let mut dst: [$elem; $iter_cnt] = [$zero; $iter_cnt];
    for (j, lane) in dst.iter_mut().enumerate() {
        if k & (0b1 << j) != 0 {
            *lane = <$elem>::$binary_func(a[j], b[j]);
        }
    }
    transmute(dst)
}
        )*
    };
}

/// Lane kernel for the widening `mul_epi32` / `mul_epu32` rows below.
///
/// Intel defines these intrinsics on 64-bit lanes. Only the low 32 bits of each lane
/// are used, sign-extended (`epi32`) or zero-extended (`epu32`) to 64 bits, and the full
/// 64-bit product is stored. A 32 x 32-bit product always fits in 64 bits, so these
/// multiplications cannot overflow.
trait MulLow32 {
    fn mul_low32(self, rhs: Self) -> Self;
}

impl MulLow32 for i64 {
    fn mul_low32(self, rhs: i64) -> i64 {
        // `as i32` keeps the low 32 bits; `i64::from` sign-extends them back.
        i64::from(self as i32) * i64::from(rhs as i32)
    }
}

impl MulLow32 for u64 {
    fn mul_low32(self, rhs: u64) -> u64 {
        // `as u32` keeps the low 32 bits; `u64::from` zero-extends them back.
        u64::from(self as u32) * u64::from(rhs as u32)
    }
}

// Integer add/sub rows use `wrapping_*` because the hardware wraps on overflow, whereas
// `Add::add`/`Sub::sub` on integers panic on overflow in debug builds. Float rows keep
// the `core::ops` trait methods. Each mask type has at least as many bits as its row has
// lanes (e.g. 16 x i8 takes `__mmask16`).
impl_mask_arith_binary_vector! {
    _mm_mask_add_epi8, _mm_maskz_add_epi8, __m128i, __mmask16, wrapping_add, 0, [i8; 16];
    _mm_mask_add_epi16, _mm_maskz_add_epi16, __m128i, __mmask8, wrapping_add, 0, [i16; 8];
    _mm_mask_add_epi32, _mm_maskz_add_epi32, __m128i, __mmask8, wrapping_add, 0, [i32; 4];
    _mm_mask_add_epi64, _mm_maskz_add_epi64, __m128i, __mmask8, wrapping_add, 0, [i64; 2];
    _mm256_mask_add_epi8,  _mm256_maskz_add_epi8,  __m256i, __mmask32, wrapping_add, 0, [i8; 32];
    _mm256_mask_add_epi16, _mm256_maskz_add_epi16, __m256i, __mmask16, wrapping_add, 0, [i16; 16];
    _mm256_mask_add_epi32, _mm256_maskz_add_epi32, __m256i, __mmask8, wrapping_add, 0, [i32; 8];
    _mm256_mask_add_epi64, _mm256_maskz_add_epi64, __m256i, __mmask8, wrapping_add, 0, [i64; 4];
    _mm512_mask_add_epi8,  _mm512_maskz_add_epi8,  __m512i, __mmask64, wrapping_add, 0, [i8; 64];
    _mm512_mask_add_epi16, _mm512_maskz_add_epi16, __m512i, __mmask32, wrapping_add, 0, [i16; 32];
    _mm512_mask_add_epi32, _mm512_maskz_add_epi32, __m512i, __mmask16, wrapping_add, 0, [i32; 16];
    _mm512_mask_add_epi64, _mm512_maskz_add_epi64, __m512i, __mmask8, wrapping_add, 0, [i64; 8];

    _mm_mask_add_pd, _mm_maskz_add_pd, __m128d, __mmask8, add, 0.0, [f64; 2];
    _mm256_mask_add_pd, _mm256_maskz_add_pd, __m256d, __mmask8, add, 0.0, [f64; 4];
    _mm512_mask_add_pd, _mm512_maskz_add_pd, __m512d, __mmask8, add, 0.0, [f64; 8];
    _mm_mask_add_ps, _mm_maskz_add_ps, __m128, __mmask8, add, 0.0, [f32; 4];
    _mm256_mask_add_ps, _mm256_maskz_add_ps, __m256, __mmask8, add, 0.0, [f32; 8];
    _mm512_mask_add_ps, _mm512_maskz_add_ps, __m512, __mmask16, add, 0.0, [f32; 16];

    _mm_mask_adds_epi8, _mm_maskz_adds_epi8, __m128i, __mmask16, saturating_add, 0, [i8; 16];
    _mm_mask_adds_epi16, _mm_maskz_adds_epi16, __m128i, __mmask8, saturating_add, 0, [i16; 8];
    _mm_mask_adds_epu8, _mm_maskz_adds_epu8, __m128i, __mmask16, saturating_add, 0, [u8; 16];
    _mm_mask_adds_epu16, _mm_maskz_adds_epu16, __m128i, __mmask8, saturating_add, 0, [u16; 8];
    _mm256_mask_adds_epi8, _mm256_maskz_adds_epi8, __m256i, __mmask32, saturating_add, 0, [i8; 32];
    _mm256_mask_adds_epi16, _mm256_maskz_adds_epi16, __m256i, __mmask16, saturating_add, 0, [i16; 16];
    _mm256_mask_adds_epu8, _mm256_maskz_adds_epu8, __m256i, __mmask32, saturating_add, 0, [u8; 32];
    _mm256_mask_adds_epu16, _mm256_maskz_adds_epu16, __m256i, __mmask16, saturating_add, 0, [u16; 16];
    _mm512_mask_adds_epi8, _mm512_maskz_adds_epi8, __m512i, __mmask64, saturating_add, 0, [i8; 64];
    _mm512_mask_adds_epi16, _mm512_maskz_adds_epi16, __m512i, __mmask32, saturating_add, 0, [i16; 32];
    _mm512_mask_adds_epu8, _mm512_maskz_adds_epu8, __m512i, __mmask64, saturating_add, 0, [u8; 64];
    _mm512_mask_adds_epu16, _mm512_maskz_adds_epu16, __m512i, __mmask32, saturating_add, 0, [u16; 32];

    _mm_mask_sub_epi8, _mm_maskz_sub_epi8, __m128i, __mmask16, wrapping_sub, 0, [i8; 16];
    _mm_mask_sub_epi16, _mm_maskz_sub_epi16, __m128i, __mmask8, wrapping_sub, 0, [i16; 8];
    _mm_mask_sub_epi32, _mm_maskz_sub_epi32, __m128i, __mmask8, wrapping_sub, 0, [i32; 4];
    _mm_mask_sub_epi64, _mm_maskz_sub_epi64, __m128i, __mmask8, wrapping_sub, 0, [i64; 2];
    _mm256_mask_sub_epi8,  _mm256_maskz_sub_epi8,  __m256i, __mmask32, wrapping_sub, 0, [i8; 32];
    _mm256_mask_sub_epi16, _mm256_maskz_sub_epi16, __m256i, __mmask16, wrapping_sub, 0, [i16; 16];
    _mm256_mask_sub_epi32, _mm256_maskz_sub_epi32, __m256i, __mmask8, wrapping_sub, 0, [i32; 8];
    _mm256_mask_sub_epi64, _mm256_maskz_sub_epi64, __m256i, __mmask8, wrapping_sub, 0, [i64; 4];
    _mm512_mask_sub_epi8,  _mm512_maskz_sub_epi8,  __m512i, __mmask64, wrapping_sub, 0, [i8; 64];
    _mm512_mask_sub_epi16, _mm512_maskz_sub_epi16, __m512i, __mmask32, wrapping_sub, 0, [i16; 32];
    _mm512_mask_sub_epi32, _mm512_maskz_sub_epi32, __m512i, __mmask16, wrapping_sub, 0, [i32; 16];
    _mm512_mask_sub_epi64, _mm512_maskz_sub_epi64, __m512i, __mmask8, wrapping_sub, 0, [i64; 8];

    _mm_mask_sub_pd, _mm_maskz_sub_pd, __m128d, __mmask8, sub, 0.0, [f64; 2];
    _mm256_mask_sub_pd, _mm256_maskz_sub_pd, __m256d, __mmask8, sub, 0.0, [f64; 4];
    _mm512_mask_sub_pd, _mm512_maskz_sub_pd, __m512d, __mmask8, sub, 0.0, [f64; 8];
    _mm_mask_sub_ps, _mm_maskz_sub_ps, __m128, __mmask8, sub, 0.0, [f32; 4];
    _mm256_mask_sub_ps, _mm256_maskz_sub_ps, __m256, __mmask8, sub, 0.0, [f32; 8];
    _mm512_mask_sub_ps, _mm512_maskz_sub_ps, __m512, __mmask16, sub, 0.0, [f32; 16];

    _mm_mask_subs_epi8, _mm_maskz_subs_epi8, __m128i, __mmask16, saturating_sub, 0, [i8; 16];
    _mm_mask_subs_epi16, _mm_maskz_subs_epi16, __m128i, __mmask8, saturating_sub, 0, [i16; 8];
    _mm_mask_subs_epu8, _mm_maskz_subs_epu8, __m128i, __mmask16, saturating_sub, 0, [u8; 16];
    _mm_mask_subs_epu16, _mm_maskz_subs_epu16, __m128i, __mmask8, saturating_sub, 0, [u16; 8];
    _mm256_mask_subs_epi8, _mm256_maskz_subs_epi8, __m256i, __mmask32, saturating_sub, 0, [i8; 32];
    _mm256_mask_subs_epi16, _mm256_maskz_subs_epi16, __m256i, __mmask16, saturating_sub, 0, [i16; 16];
    _mm256_mask_subs_epu8, _mm256_maskz_subs_epu8, __m256i, __mmask32, saturating_sub, 0, [u8; 32];
    _mm256_mask_subs_epu16, _mm256_maskz_subs_epu16, __m256i, __mmask16, saturating_sub, 0, [u16; 16];
    _mm512_mask_subs_epi8, _mm512_maskz_subs_epi8, __m512i, __mmask64, saturating_sub, 0, [i8; 64];
    _mm512_mask_subs_epi16, _mm512_maskz_subs_epi16, __m512i, __mmask32, saturating_sub, 0, [i16; 32];
    _mm512_mask_subs_epu8, _mm512_maskz_subs_epu8, __m512i, __mmask64, saturating_sub, 0, [u8; 64];
    _mm512_mask_subs_epu16, _mm512_maskz_subs_epu16, __m512i, __mmask32, saturating_sub, 0, [u16; 32];

    // Widening multiplies: one mask bit per 64-bit lane (see `MulLow32`).
    _mm_mask_mul_epi32, _mm_maskz_mul_epi32, __m128i, __mmask8, mul_low32, 0, [i64; 2];
    _mm256_mask_mul_epi32, _mm256_maskz_mul_epi32, __m256i, __mmask8, mul_low32, 0, [i64; 4];
    _mm512_mask_mul_epi32, _mm512_maskz_mul_epi32, __m512i, __mmask8, mul_low32, 0, [i64; 8];
    _mm_mask_mul_epu32, _mm_maskz_mul_epu32, __m128i, __mmask8, mul_low32, 0, [u64; 2];
    _mm256_mask_mul_epu32, _mm256_maskz_mul_epu32, __m256i, __mmask8, mul_low32, 0, [u64; 4];
    _mm512_mask_mul_epu32, _mm512_maskz_mul_epu32, __m512i, __mmask8, mul_low32, 0, [u64; 8];

    _mm_mask_mul_pd, _mm_maskz_mul_pd, __m128d, __mmask8, mul, 0.0, [f64; 2];
    _mm256_mask_mul_pd, _mm256_maskz_mul_pd, __m256d, __mmask8, mul, 0.0, [f64; 4];
    _mm512_mask_mul_pd, _mm512_maskz_mul_pd, __m512d, __mmask8, mul, 0.0, [f64; 8];
    _mm_mask_mul_ps, _mm_maskz_mul_ps, __m128, __mmask8, mul, 0.0, [f32; 4];
    _mm256_mask_mul_ps, _mm256_maskz_mul_ps, __m256, __mmask8, mul, 0.0, [f32; 8];
    _mm512_mask_mul_ps, _mm512_maskz_mul_ps, __m512, __mmask16, mul, 0.0, [f32; 16];

    _mm_mask_div_pd, _mm_maskz_div_pd, __m128d, __mmask8, div, 0.0, [f64; 2];
    _mm256_mask_div_pd, _mm256_maskz_div_pd, __m256d, __mmask8, div, 0.0, [f64; 4];
    _mm512_mask_div_pd, _mm512_maskz_div_pd, __m512d, __mmask8, div, 0.0, [f64; 8];
    _mm_mask_div_ps, _mm_maskz_div_ps, __m128, __mmask8, div, 0.0, [f32; 4];
    _mm256_mask_div_ps, _mm256_maskz_div_ps, __m256, __mmask8, div, 0.0, [f32; 8];
    _mm512_mask_div_ps, _mm512_maskz_div_ps, __m512, __mmask16, div, 0.0, [f32; 16];
}

/// Generates masked (`_mask_`) and zero-masked (`_maskz_`) scalar (lane-0) binary
/// functions.
///
/// Row format matches `impl_mask_arith_binary_vector`:
/// `mask_fn, maskz_fn, VecType, MaskType, func, zero, [Elem; LANES];`.
/// Only bit 0 of `k` is consulted. Lane 0 of the result is `func(a[0], b[0])` when it is
/// set; otherwise it is `src[0]` (`_mask_`) or `zero` (`_maskz_`). Lanes `1..LANES` are
/// always copied from `a`.
macro_rules! impl_mask_arith_binary_scalar {
    ($($fn_name_mask: ident, $fn_name_maskz: ident,
    $vec_type: ty, $mask_type: ty, $binary_func: ident, $zero: expr,
    [$elem: ty; $iter_cnt: expr];)*) => {
        $(
/// Lane 0 is `func(a[0], b[0])` if bit 0 of `k` is set, else `src[0]`; other lanes are `a`'s.
pub unsafe fn $fn_name_mask(src: $vec_type, k: $mask_type, a: $vec_type, b: $vec_type) -> $vec_type {
    let src: [$elem; $iter_cnt] = transmute(src);
    let b: [$elem; $iter_cnt] = transmute(b);
    // Upper lanes come from `a`, so start from a copy of it; `dst[0]` is still `a[0]` here.
    let mut dst: [$elem; $iter_cnt] = transmute(a);
    dst[0] = if k & 0b1 != 0 {
        <$elem>::$binary_func(dst[0], b[0])
    } else {
        src[0]
    };
    transmute(dst)
}

/// Lane 0 is `func(a[0], b[0])` if bit 0 of `k` is set, else `zero`; other lanes are `a`'s.
pub unsafe fn $fn_name_maskz(k: $mask_type, a: $vec_type, b: $vec_type) -> $vec_type {
    let b: [$elem; $iter_cnt] = transmute(b);
    // Upper lanes come from `a`, so start from a copy of it; `dst[0]` is still `a[0]` here.
    let mut dst: [$elem; $iter_cnt] = transmute(a);
    dst[0] = if k & 0b1 != 0 { <$elem>::$binary_func(dst[0], b[0]) } else { $zero };
    transmute(dst)
}
        )*
    };
}

// `_ss` variants operate on `__m128` (4 x f32) and `_sd` variants on `__m128d` (2 x f64),
// matching Intel's signatures (e.g. `_mm_mask_add_ss(__m128 src, __mmask8 k, __m128 a, __m128 b)`).
impl_mask_arith_binary_scalar! {
    _mm_mask_add_ss, _mm_maskz_add_ss, __m128, __mmask8, add, 0.0, [f32; 4];
    _mm_mask_add_sd, _mm_maskz_add_sd, __m128d, __mmask8, add, 0.0, [f64; 2];
    _mm_mask_sub_ss, _mm_maskz_sub_ss, __m128, __mmask8, sub, 0.0, [f32; 4];
    _mm_mask_sub_sd, _mm_maskz_sub_sd, __m128d, __mmask8, sub, 0.0, [f64; 2];
    _mm_mask_mul_ss, _mm_maskz_mul_ss, __m128, __mmask8, mul, 0.0, [f32; 4];
    _mm_mask_mul_sd, _mm_maskz_mul_sd, __m128d, __mmask8, mul, 0.0, [f64; 2];
    _mm_mask_div_ss, _mm_maskz_div_ss, __m128, __mmask8, div, 0.0, [f32; 4];
    _mm_mask_div_sd, _mm_maskz_div_sd, __m128d, __mmask8, div, 0.0, [f64; 2];
}

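/*
    TODO: `*_round_*` variants are not implemented yet. Intel's versions take an extra
    rounding-mode argument, which the macros above do not generate, so they need their
    own generator. Candidate rows:
    NOTE(review): embedded rounding is defined for the 512-bit and scalar (`_sd`/`_ss`)
    forms; confirm that the 128-bit `_mm_*_round_pd/ps` rows below exist in the Intel
    Intrinsics Guide before implementing them.

    _mm_mask_add_round_pd, _mm_maskz_add_round_pd, __m128d, __mmask8, add, 0.0, [f64; 2];
    _mm512_mask_add_round_pd, _mm512_maskz_add_round_pd, __m512d, __mmask8, add, 0.0, [f64; 8];
    _mm_mask_add_round_ps, _mm_maskz_add_round_ps, __m128, __mmask8, add, 0.0, [f32; 4];
    _mm512_mask_add_round_ps, _mm512_maskz_add_round_ps, __m512, __mmask16, add, 0.0, [f32; 16];
*/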