#pragma once
#include <x86intrin.h>
#ifdef WIN32
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif
#include <cmath>
#include <cstdint>
#include <type_traits>
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
namespace megdnn {
namespace x86 {
/**
 * Horizontally reduce two vectors of 8 x int32 at once.
 *
 * Writes the (wrapping) sum of all eight lanes of `a` to output_ptr[0] and
 * the sum of all eight lanes of `b` to output_ptr[1]. `a` is used as scratch
 * and is overwritten (lanes 0/1 end up holding the two sums); the function
 * never writes through `b`.
 */
MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void _mm256_reduce_two_epi32_to_ptr(
        __m256i& a, __m256i& b, int32_t* output_ptr) {
    const __m256i zero = _mm256_setzero_si256();
    // hadd works inside each 128-bit lane: after two rounds the low lane is
    // [sum(a_lo), sum(b_lo), 0, 0] and the high lane [sum(a_hi), sum(b_hi), 0, 0].
    const __m256i pair_sums = _mm256_hadd_epi32(a, b);
    const __m256i lane_sums = _mm256_hadd_epi32(pair_sums, zero);
    // Selector 0x31 moves the high lane of `lane_sums` down into the low lane
    // (and zeroes the upper half), so a single add folds both halves.
    const __m256i upper_half = _mm256_permute2x128_si256(lane_sums, zero, 0x31);
    a = _mm256_add_epi32(lane_sums, upper_half);
    output_ptr[0] = _mm256_extract_epi32(a, 0);
    output_ptr[1] = _mm256_extract_epi32(a, 1);
}
/**
 * Copy `ksize` elements from `inptr` to `outptr`, then pad with `val` until
 * `unroll_k` elements have been written in total (no padding is added when
 * ksize >= unroll_k).
 *
 * Both pointers are passed by reference and are advanced past the elements
 * read / written.
 */
template <typename T>
static inline void interleave_helper(
        const T*& inptr, T*& outptr, int unroll_k, int ksize, T val = 0) {
    int written = 0;
    while (written < ksize) {
        *outptr++ = *inptr++;
        ++written;
    }
    while (written < unroll_k) {
        *outptr++ = val;
        ++written;
    }
}
/**
 * Variant of interleave_helper that converts int8 to uint8 by adding 128
 * (mod 256) to every element, including the padding value `val`.
 *
 * Copies `ksize` converted elements, then pads with `val + 128` until
 * `unroll_k` elements have been written. Both pointers are advanced.
 */
static inline void interleave_helper_add_128(
        const int8_t*& inptr, uint8_t*& outptr, int unroll_k, int ksize,
        uint8_t val = 0) {
    const uint8_t padded = static_cast<uint8_t>(val + 128u);
    int pos = 0;
    for (; pos < ksize; ++pos, ++inptr, ++outptr) {
        *outptr = static_cast<uint8_t>(*inptr + 128u);
    }
    for (; pos < unroll_k; ++pos) {
        *outptr++ = padded;
    }
}
/**
 * Non-advancing variant of interleave_helper: writes inptr[0..ksize) to
 * outptr, then fills with `val` up to index `unroll_k`. The caller's pointers
 * are taken by value and are not modified.
 */
template <typename T>
static inline void interleave_helper_no_inc(
        T* outptr, const T* inptr, int unroll_k, int ksize, T val = 0) {
    int idx = 0;
    for (; idx < ksize; ++idx) {
        outptr[idx] = inptr[idx];
    }
    for (; idx < unroll_k; ++idx) {
        outptr[idx] = val;
    }
}
/**
 * Copy the first `k` bytes of two rows into consecutive 16-byte slots of
 * `out`, zero-padding each slot up to 16 bytes. The input pointers are not
 * advanced. `k` is expected to be <= 16; a larger `k` writes past each slot.
 */
static inline void interleave_2x16_pad(
        dt_int8* out, const dt_int8* in0, const dt_int8* in1, int k) {
    const dt_int8* const rows[2] = {in0, in1};
    for (int r = 0; r < 2; ++r) {
        interleave_helper_no_inc(out + 16 * r, rows[r], 16, k);
    }
}
/**
 * Copy the first `k` bytes of four rows into consecutive 16-byte slots of
 * `out` (64 bytes total), zero-padding each slot up to 16 bytes. The input
 * pointers are not advanced. `k` is expected to be <= 16.
 */
static inline void interleave_4x16_pad(
        dt_int8* out, const dt_int8* in0, const dt_int8* in1, const dt_int8* in2,
        const dt_int8* in3, int k) {
    const dt_int8* const rows[4] = {in0, in1, in2, in3};
    for (int r = 0; r < 4; ++r) {
        interleave_helper_no_inc(out + 16 * r, rows[r], 16, k);
    }
}
/**
 * Copy 16 bytes from each of two rows back to back into `out`:
 * out[0..16) = in0[0..16), out[16..32) = in1[0..16).
 */
MEGDNN_ATTRIBUTE_TARGET("sse3")
static inline void interleave_2x16(
        dt_int8* out, const dt_int8* in0, const dt_int8* in1) {
    // Rows are copied one at a time (load, then store) before the next row
    // is read.
    __m128i row = _mm_loadu_si128(reinterpret_cast<const __m128i*>(in0));
    _mm_storeu_si128(reinterpret_cast<__m128i*>(out), row);
    row = _mm_loadu_si128(reinterpret_cast<const __m128i*>(in1));
    _mm_storeu_si128(reinterpret_cast<__m128i*>(out + 16), row);
}
/**
 * Copy 16 bytes from each of four rows back to back into `out`
 * (64 bytes): out[16 * r + i] = in_r[i].
 */
MEGDNN_ATTRIBUTE_TARGET("sse3")
static inline void interleave_4x16(
        dt_int8* out, const dt_int8* in0, const dt_int8* in1, const dt_int8* in2,
        const dt_int8* in3) {
    const dt_int8* const rows[4] = {in0, in1, in2, in3};
    // One load/store per row, in row order.
    for (int r = 0; r < 4; ++r) {
        const __m128i chunk =
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[r]));
        _mm_storeu_si128(reinterpret_cast<__m128i*>(out + 16 * r), chunk);
    }
}
/**
 * Interleave four rows in chunks of `unroll_k` elements.
 *
 * For each chunk of columns, emits the chunk from row 0, then rows 1, 2 and
 * 3. A trailing partial chunk is padded with `val` up to `unroll_k`. All four
 * input pointers and `outptr` are advanced. `unroll_k` must be positive,
 * otherwise the loop does not terminate.
 */
template <typename T>
static inline void interleave_4(
        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
        T*& outptr, int unroll_k, int ksize, T val = 0) {
    const T** const rows[4] = {&inptr0, &inptr1, &inptr2, &inptr3};
    for (int remaining = ksize; remaining > 0; remaining -= unroll_k) {
        const int valid = std::min(unroll_k, remaining);
        for (const T** row : rows) {
            interleave_helper(*row, outptr, unroll_k, valid, val);
        }
    }
}
/**
 * Same layout as interleave_4, but converts int8 to uint8 by adding 128
 * (mod 256) to every element, including the padding value `val`.
 * All input pointers and `outptr` are advanced.
 */
static inline void interleave_4_add_128(
        const int8_t*& inptr0, const int8_t*& inptr1, const int8_t*& inptr2,
        const int8_t*& inptr3, uint8_t*& outptr, int unroll_k, int ksize,
        uint8_t val = 0) {
    const int8_t** const rows[4] = {&inptr0, &inptr1, &inptr2, &inptr3};
    for (int remaining = ksize; remaining > 0; remaining -= unroll_k) {
        const int valid = std::min(unroll_k, remaining);
        for (const int8_t** row : rows) {
            interleave_helper_add_128(*row, outptr, unroll_k, valid, val);
        }
    }
}
/**
 * Interleave 12 rows in chunks of `unroll_k` elements (row 0's chunk, then
 * row 1's, ... row 11's), padding a trailing partial chunk with `val`.
 * Every entry of `input` and `outptr` is advanced.
 */
template <typename T>
static inline void interleave_12(
        const T* (&input)[12], T*& outptr, int unroll_k, int ksize, T val = 0) {
    for (int remaining = ksize; remaining > 0; remaining -= unroll_k) {
        const int valid = std::min(unroll_k, remaining);
        for (auto& row : input) {
            interleave_helper(row, outptr, unroll_k, valid, val);
        }
    }
}
/**
 * Same layout as interleave_12, converting int8 to uint8 by adding 128
 * (mod 256) to every element, including the padding value `val`.
 */
static inline void interleave_12_add_128(
        const int8_t* (&input)[12], uint8_t*& outptr, int unroll_k, int ksize,
        uint8_t val = 0) {
    for (int remaining = ksize; remaining > 0; remaining -= unroll_k) {
        const int valid = std::min(unroll_k, remaining);
        for (auto& row : input) {
            interleave_helper_add_128(row, outptr, unroll_k, valid, val);
        }
    }
}
/**
 * Interleave 16 rows in chunks of `unroll_k` elements, padding a trailing
 * partial chunk with `val`. Every entry of `input` and `outptr` is advanced.
 */
template <typename T>
static inline void interleave_16(
        const T* (&input)[16], T*& outptr, int unroll_k, int ksize, T val = 0) {
    for (int remaining = ksize; remaining > 0; remaining -= unroll_k) {
        const int valid = std::min(unroll_k, remaining);
        for (auto& row : input) {
            interleave_helper(row, outptr, unroll_k, valid, val);
        }
    }
}
/**
 * Interleave 32 rows in chunks of `unroll_k` elements, padding a trailing
 * partial chunk with `val`. Every entry of `input` and `outptr` is advanced.
 */
template <typename T>
static inline void interleave_32(
        const T* (&input)[32], T*& outptr, int unroll_k, int ksize, T val = 0) {
    for (int remaining = ksize; remaining > 0; remaining -= unroll_k) {
        const int valid = std::min(unroll_k, remaining);
        for (auto& row : input) {
            interleave_helper(row, outptr, unroll_k, valid, val);
        }
    }
}
/**
 * Load 16 bytes from each of four rows, add 128 (mod 256) to every byte, and
 * write 64 bytes where 16-byte block q (q = 0..3) holds the q-th 4-byte word
 * of rows 0, 1, 2, 3 in order.
 *
 * All input pointers are advanced by 16 and `outptr` by 64.
 */
MEGDNN_ATTRIBUTE_TARGET("sse3")
static inline void interleave_4x4_4_b_add_128(
        const int8_t*& input0, const int8_t*& input1, const int8_t*& input2,
        const int8_t*& input3, uint8_t*& outptr) {
    // Adding -128 modulo 256 is the same as adding 128: maps int8 to uint8.
    const __m128i bias = _mm_set1_epi8(-128);
    const __m128i a = _mm_add_epi8(
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(input0)), bias);
    const __m128i b = _mm_add_epi8(
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(input1)), bias);
    const __m128i c = _mm_add_epi8(
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(input2)), bias);
    const __m128i d = _mm_add_epi8(
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(input3)), bias);
    // 4x4 transpose of 32-bit words.
    const __m128i ab_lo = _mm_unpacklo_epi32(a, b);
    const __m128i ab_hi = _mm_unpackhi_epi32(a, b);
    const __m128i cd_lo = _mm_unpacklo_epi32(c, d);
    const __m128i cd_hi = _mm_unpackhi_epi32(c, d);
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr), _mm_unpacklo_epi64(ab_lo, cd_lo));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 16),
            _mm_unpackhi_epi64(ab_lo, cd_lo));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 32),
            _mm_unpacklo_epi64(ab_hi, cd_hi));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 48),
            _mm_unpackhi_epi64(ab_hi, cd_hi));
    input0 += 16;
    input1 += 16;
    input2 += 16;
    input3 += 16;
    outptr += 64;
}
/**
 * Load 16 bytes from each of 12 rows, add 128 (mod 256) to every byte, and
 * write 192 bytes ordered by 4-byte word: for word q = 0..3, the q-th word of
 * rows 0..11 in order (out[q * 48 + r * 4 + j] = row_r[4 * q + j] + 128).
 *
 * Every entry of `input` is advanced by 16 and `outptr` by 192.
 */
MEGDNN_ATTRIBUTE_TARGET("sse3")
static inline void interleave_12x4_4_b_add_128(
        const int8_t* (&input)[12], uint8_t*& outptr) {
    constexpr int kGroups = 3;  // 12 rows = 3 groups of 4 rows
    // Adding -128 modulo 256 is the same as adding 128: maps int8 to uint8.
    const __m128i bias = _mm_set1_epi8(-128);
    __m128i words[4][kGroups];  // words[q][g]: word q of rows 4g..4g+3
    for (int g = 0; g < kGroups; ++g) {
        const int8_t* const* rows = input + 4 * g;
        const __m128i a = _mm_add_epi8(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[0])), bias);
        const __m128i b = _mm_add_epi8(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[1])), bias);
        const __m128i c = _mm_add_epi8(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[2])), bias);
        const __m128i d = _mm_add_epi8(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[3])), bias);
        const __m128i ab_lo = _mm_unpacklo_epi32(a, b);
        const __m128i ab_hi = _mm_unpackhi_epi32(a, b);
        const __m128i cd_lo = _mm_unpacklo_epi32(c, d);
        const __m128i cd_hi = _mm_unpackhi_epi32(c, d);
        words[0][g] = _mm_unpacklo_epi64(ab_lo, cd_lo);
        words[1][g] = _mm_unpackhi_epi64(ab_lo, cd_lo);
        words[2][g] = _mm_unpacklo_epi64(ab_hi, cd_hi);
        words[3][g] = _mm_unpackhi_epi64(ab_hi, cd_hi);
    }
    for (int q = 0; q < 4; ++q) {
        for (int g = 0; g < kGroups; ++g) {
            _mm_storeu_si128(reinterpret_cast<__m128i*>(outptr), words[q][g]);
            outptr += 16;
        }
    }
    for (int r = 0; r < 12; ++r) {
        input[r] += 16;
    }
}
/**
 * Load 16 bytes from each of 16 rows and write 256 bytes ordered by 4-byte
 * word: for word q = 0..3, the q-th word of rows 0..15 in order
 * (out[q * 64 + r * 4 + j] = row_r[4 * q + j]).
 *
 * Every entry of `input` is advanced by 16 and `outptr` by 256.
 */
template <typename T>
MEGDNN_ATTRIBUTE_TARGET("sse3")
static inline void interleave_16x4_4_b(const T* (&input)[16], T*& outptr) {
    static_assert(
            std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
            "interleave_16x4_4_b only support uint8_t and int8_t");
    constexpr int kGroups = 4;  // 16 rows = 4 groups of 4 rows
    __m128i words[4][kGroups];  // words[q][g]: word q of rows 4g..4g+3
    for (int g = 0; g < kGroups; ++g) {
        const T* const* rows = input + 4 * g;
        const __m128i a = _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[0]));
        const __m128i b = _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[1]));
        const __m128i ab_lo = _mm_unpacklo_epi32(a, b);
        const __m128i ab_hi = _mm_unpackhi_epi32(a, b);
        const __m128i c = _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[2]));
        const __m128i d = _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[3]));
        const __m128i cd_lo = _mm_unpacklo_epi32(c, d);
        const __m128i cd_hi = _mm_unpackhi_epi32(c, d);
        words[0][g] = _mm_unpacklo_epi64(ab_lo, cd_lo);
        words[1][g] = _mm_unpackhi_epi64(ab_lo, cd_lo);
        words[2][g] = _mm_unpacklo_epi64(ab_hi, cd_hi);
        words[3][g] = _mm_unpackhi_epi64(ab_hi, cd_hi);
    }
    for (int q = 0; q < 4; ++q) {
        for (int g = 0; g < kGroups; ++g) {
            _mm_storeu_si128(reinterpret_cast<__m128i*>(outptr), words[q][g]);
            outptr += 16;
        }
    }
    for (int r = 0; r < 16; ++r) {
        input[r] += 16;
    }
}
/**
 * Load 16 bytes from each of 32 rows and write 512 bytes ordered by 4-byte
 * word: for word q = 0..3, the q-th word of rows 0..31 in order
 * (out[q * 128 + r * 4 + j] = row_r[4 * q + j]).
 *
 * Every entry of `input` is advanced by 16 and `outptr` by 512.
 */
template <typename T>
MEGDNN_ATTRIBUTE_TARGET("sse3")
static inline void interleave_32x4_4_b(const T* (&input)[32], T*& outptr) {
    static_assert(
            std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
            "interleave_32x4_4_b only support uint8_t and int8_t");
    constexpr int kGroups = 8;  // 32 rows = 8 groups of 4 rows
    __m128i words[4][kGroups];  // words[q][g]: word q of rows 4g..4g+3
    for (int g = 0; g < kGroups; ++g) {
        const T* const* rows = input + 4 * g;
        const __m128i a = _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[0]));
        const __m128i b = _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[1]));
        const __m128i ab_lo = _mm_unpacklo_epi32(a, b);
        const __m128i ab_hi = _mm_unpackhi_epi32(a, b);
        const __m128i c = _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[2]));
        const __m128i d = _mm_loadu_si128(reinterpret_cast<const __m128i*>(rows[3]));
        const __m128i cd_lo = _mm_unpacklo_epi32(c, d);
        const __m128i cd_hi = _mm_unpackhi_epi32(c, d);
        words[0][g] = _mm_unpacklo_epi64(ab_lo, cd_lo);
        words[1][g] = _mm_unpackhi_epi64(ab_lo, cd_lo);
        words[2][g] = _mm_unpacklo_epi64(ab_hi, cd_hi);
        words[3][g] = _mm_unpackhi_epi64(ab_hi, cd_hi);
    }
    for (int q = 0; q < 4; ++q) {
        for (int g = 0; g < kGroups; ++g) {
            _mm_storeu_si128(reinterpret_cast<__m128i*>(outptr), words[q][g]);
            outptr += 16;
        }
    }
    for (int r = 0; r < 32; ++r) {
        input[r] += 16;
    }
}
/**
 * Scalar transpose of a 16 x n byte tile: for each column i in [0, n),
 * writes in0[i], in1[i], ..., in15[i] to `out` (16 * n bytes total).
 * The caller's pointers are taken by value and are not modified.
 */
static inline void naive_transpose_16xn(
        dt_int8* out, const dt_int8* in0, const dt_int8* in1, const dt_int8* in2,
        const dt_int8* in3, const dt_int8* in4, const dt_int8* in5, const dt_int8* in6,
        const dt_int8* in7, const dt_int8* in8, const dt_int8* in9, const dt_int8* in10,
        const dt_int8* in11, const dt_int8* in12, const dt_int8* in13,
        const dt_int8* in14, const dt_int8* in15, int n) {
    const dt_int8* rows[16] = {in0,  in1,  in2,  in3,  in4,  in5,  in6,  in7,
                               in8,  in9,  in10, in11, in12, in13, in14, in15};
    for (int col = 0; col < n; ++col) {
        for (const dt_int8*& row : rows) {
            *out++ = *row++;
        }
    }
}
/**
 * Scalar repack of an n x k row-major byte matrix (row stride `ldin`) into
 * k-pairs. For each pair of columns (kk, kk + 1), writes the two bytes of
 * rows 0..n-1 in order, using 0 for a missing odd trailing column. Then
 * writes two zero bytes for each padding row n..n_unroll-1.
 */
static inline void naive_transpose_nk_k2(
        dt_int8* out, const dt_int8* in, int ldin, int n, int k, int n_unroll) {
    constexpr int k_step = 2;
    for (int kk = 0; kk < k; kk += k_step) {
        const bool has_second = kk + 1 < k;
        for (int row = 0; row < n; ++row) {
            const dt_int8* src = in + row * ldin + kk;
            out[0] = src[0];
            out[1] = has_second ? src[1] : 0;
            out += 2;
        }
        for (int row = n; row < n_unroll; ++row) {
            out[0] = 0;
            out[1] = 0;
            out += 2;
        }
    }
}
/**
 * Scalar repack of 16 rows into k-pairs. For every complete pair of columns,
 * writes two consecutive bytes from each of in0..in15 in row order. If k_max
 * is odd, a final step writes the last byte of each row followed by a 0.
 */
static inline void naive_transpose_16xk_k2(
        dt_int8* out, const dt_int8* in0, const dt_int8* in1, const dt_int8* in2,
        const dt_int8* in3, const dt_int8* in4, const dt_int8* in5, const dt_int8* in6,
        const dt_int8* in7, const dt_int8* in8, const dt_int8* in9, const dt_int8* in10,
        const dt_int8* in11, const dt_int8* in12, const dt_int8* in13,
        const dt_int8* in14, const dt_int8* in15, int k_max) {
    constexpr int k_step = 2;
    const dt_int8* rows[16] = {in0,  in1,  in2,  in3,  in4,  in5,  in6,  in7,
                               in8,  in9,  in10, in11, in12, in13, in14, in15};
    const int pair_count = k_max / k_step;
    const int k_remain = k_max - pair_count * k_step;
    for (int p = 0; p < pair_count; ++p) {
        for (const dt_int8*& row : rows) {
            out[0] = row[0];
            out[1] = row[1];
            out += 2;
            row += 2;
        }
    }
    if (k_remain > 0) {
        for (const dt_int8* row : rows) {
            out[0] = row[0];
            out[1] = 0;
            out += 2;
        }
    }
}
/**
 * Scalar repack of 8 rows into k-pairs. For every complete pair of columns,
 * writes two consecutive bytes from each of in0..in7 in row order. If k_max
 * is odd, a final step writes the last byte of each row followed by a 0.
 */
static inline void naive_transpose_8xk_k2(
        dt_int8* out, const dt_int8* in0, const dt_int8* in1, const dt_int8* in2,
        const dt_int8* in3, const dt_int8* in4, const dt_int8* in5, const dt_int8* in6,
        const dt_int8* in7, int k_max) {
    constexpr int k_step = 2;
    const dt_int8* rows[8] = {in0, in1, in2, in3, in4, in5, in6, in7};
    const int pair_count = k_max / k_step;
    const int k_remain = k_max - pair_count * k_step;
    for (int p = 0; p < pair_count; ++p) {
        for (const dt_int8*& row : rows) {
            out[0] = row[0];
            out[1] = row[1];
            out += 2;
            row += 2;
        }
    }
    if (k_remain > 0) {
        for (const dt_int8* row : rows) {
            out[0] = row[0];
            out[1] = 0;
            out += 2;
        }
    }
}
/**
 * Scalar transpose of a k x n row-major byte matrix (row stride `ldin`)
 * into n x k: out[col * k + row] = in[row * ldin + col].
 */
static inline void naive_transpose_kn(
        dt_int8* out, const dt_int8* in, int ldin, int k, int n) {
    for (int col = 0; col < n; ++col) {
        const dt_int8* column = in + col;
        for (int row = 0; row < k; ++row) {
            *out++ = column[row * ldin];
        }
    }
}
/**
 * Padded scalar transpose of a k x n row-major int8 matrix (row stride
 * `ldin`) into an n_unroll x k_unroll block of OutType. Element
 * (col, row) is in[row * ldin + col] when row < k and col < n, and `pad`
 * otherwise. Exactly n_unroll * k_unroll elements are written.
 */
template <typename OutType>
static inline void naive_transpose_kn_pad(
        OutType* out, const dt_int8* in, int ldin, int k, int n, int k_unroll,
        int n_unroll, OutType pad = 0) {
    for (int col = 0; col < n_unroll; ++col) {
        const bool col_valid = col < n;
        for (int row = 0; row < k_unroll; ++row) {
            const bool in_range = col_valid && row < k;
            *out++ = in_range ? static_cast<OutType>(in[row * ldin + col]) : pad;
        }
    }
}
/**
 * Interleave four rows element by element: for i in [0, size), writes
 * inptr0[i], inptr1[i], inptr2[i], inptr3[i]. Then pads with four copies of
 * `val` per step until `interleave` steps have been written.
 *
 * The input pointers are advanced by `size`. `outptr` is taken by value.
 * Asserts size <= interleave.
 */
template <typename T>
static inline void transpose_4(
        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
        T* outptr, int interleave, int size, T val = 0) {
    megdnn_assert(size <= interleave);
    const T** const rows[4] = {&inptr0, &inptr1, &inptr2, &inptr3};
    int step = 0;
    for (; step < size; ++step) {
        for (const T** row : rows) {
            *outptr++ = *(*row)++;
        }
    }
    for (; step < interleave; ++step) {
        for (int r = 0; r < 4; ++r) {
            *outptr++ = val;
        }
    }
}
/**
 * Interleave two rows element by element without advancing the caller's
 * pointers: for i in [0, size), writes inptr0[i], inptr1[i]. Then pads with
 * pairs of `val` until `interleave` steps have been written.
 * Asserts size <= interleave.
 */
template <typename T>
static inline void transpose_2_no_inc(
        const T* inptr0, const T* inptr1, T* outptr, int interleave, int size,
        T val = 0) {
    megdnn_assert(size <= interleave);
    int step = 0;
    for (; step < size; ++step) {
        outptr[0] = inptr0[step];
        outptr[1] = inptr1[step];
        outptr += 2;
    }
    for (; step < interleave; ++step) {
        outptr[0] = val;
        outptr[1] = val;
        outptr += 2;
    }
}
/**
 * Same layout as transpose_4, but converts int8 to uint8 by adding 128
 * (mod 256) to every element, including the padding value `val`.
 * The input pointers are advanced by `size`. Asserts size <= interleave.
 */
static inline void transpose_4_add_128(
        const int8_t*& inptr0, const int8_t*& inptr1, const int8_t*& inptr2,
        const int8_t*& inptr3, uint8_t* outptr, int interleave, int size,
        uint8_t val = 0) {
    megdnn_assert(size <= interleave);
    const int8_t** const rows[4] = {&inptr0, &inptr1, &inptr2, &inptr3};
    const uint8_t padded = static_cast<uint8_t>(val + 128u);
    int step = 0;
    for (; step < size; ++step) {
        for (const int8_t** row : rows) {
            *outptr++ = static_cast<uint8_t>(*(*row)++ + 128u);
        }
    }
    for (; step < interleave; ++step) {
        for (int r = 0; r < 4; ++r) {
            *outptr++ = padded;
        }
    }
}
/**
 * Byte-interleave two 32-byte rows: out[2 * c + r] = row_r[c] for
 * c in [0, 32). Writes 64 bytes; the pointers are not advanced.
 */
MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void transpose_2x32_no_inc(
        const int8_t* inptr0, const int8_t* inptr1, int8_t* outptr) {
    const __m256i first = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(inptr0));
    const __m256i second = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(inptr1));
    // AVX2 unpack is per 128-bit lane: `lo` covers bytes 0-7 | 16-23 and
    // `hi` bytes 8-15 | 24-31, so the lanes are stored in the order lo0, hi0,
    // lo1, hi1.
    const __m256i lo = _mm256_unpacklo_epi8(first, second);
    const __m256i hi = _mm256_unpackhi_epi8(first, second);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(outptr), _mm256_extracti128_si256(lo, 0));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 16), _mm256_extracti128_si256(hi, 0));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 32), _mm256_extracti128_si256(lo, 1));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 48), _mm256_extracti128_si256(hi, 1));
}
/**
 * Byte-interleave two 16-byte rows: out[2 * c + r] = row_r[c] for
 * c in [0, 16). Writes 32 bytes; the pointers are not advanced.
 */
MEGDNN_ATTRIBUTE_TARGET("sse3")
static inline void transpose_2x16_no_inc(
        const int8_t* inptr0, const int8_t* inptr1, int8_t* outptr) {
    const __m128i first = _mm_loadu_si128(reinterpret_cast<const __m128i*>(inptr0));
    const __m128i second = _mm_loadu_si128(reinterpret_cast<const __m128i*>(inptr1));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr), _mm_unpacklo_epi8(first, second));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 16), _mm_unpackhi_epi8(first, second));
}
/**
 * Byte-interleave two 8-byte rows: out[2 * c + r] = row_r[c] for
 * c in [0, 8). Reads 8 bytes per row and writes 16 bytes.
 */
MEGDNN_ATTRIBUTE_TARGET("sse3")
static inline void transpose_2x8_no_inc(
        const int8_t* inptr0, const int8_t* inptr1, int8_t* outptr) {
    // 64-bit loads: only the low halves participate in the interleave.
    const __m128i first = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(inptr0));
    const __m128i second = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(inptr1));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr), _mm_unpacklo_epi8(first, second));
}
/** Load 16 signed bytes from `ptr` and sign-extend them to 16 x int16. */
MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline __m256i _mm256_cvtepi8_epi16_from_ptr(const int8_t* ptr) {
    const __m128i bytes = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ptr));
    return _mm256_cvtepi8_epi16(bytes);
}
/** Load 8 signed bytes from `ptr` (64-bit load) and sign-extend to 8 x int16. */
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline __m128i _mm_cvtepi8_epi16_from_ptr(const int8_t* ptr) {
    const __m128i bytes = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ptr));
    return _mm_cvtepi8_epi16(bytes);
}
/**
 * Widen two 16-byte int8 rows to int16 and interleave them in pairs of
 * elements. For pair p in [0, 8): out[4p + 0..1] = row0[2p..2p+1] and
 * out[4p + 2..3] = row1[2p..2p+1]. Writes 32 int16 values.
 */
MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void transpose_2x16_k2_int8_to_int16(
        const int8_t* inptr0, const int8_t* inptr1, int16_t* outptr) {
    const __m256i wide0 = _mm256_cvtepi8_epi16_from_ptr(inptr0);
    const __m256i wide1 = _mm256_cvtepi8_epi16_from_ptr(inptr1);
    // 32-bit unpack moves int16 pairs; it works per 128-bit lane, so the lanes
    // are stored in the order lo0, hi0, lo1, hi1.
    const __m256i lo = _mm256_unpacklo_epi32(wide0, wide1);
    const __m256i hi = _mm256_unpackhi_epi32(wide0, wide1);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(outptr), _mm256_extracti128_si256(lo, 0));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 8), _mm256_extracti128_si256(hi, 0));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 16), _mm256_extracti128_si256(lo, 1));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 24), _mm256_extracti128_si256(hi, 1));
}
/**
 * Widen two 16-byte int8 rows to int16 and interleave them element-wise,
 * storing four 8-element tiles `tile_step` int16 apart:
 * out[t * tile_step + 2j + r] = row_r[4t + j] for t, j in [0, 4).
 */
MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void transpose_km_2x16_k2_tile4_int8_to_int16(
        const int8_t* inptr0, const int8_t* inptr1, int16_t* outptr, int tile_step) {
    const __m256i wide0 = _mm256_cvtepi8_epi16_from_ptr(inptr0);
    const __m256i wide1 = _mm256_cvtepi8_epi16_from_ptr(inptr1);
    // Per-lane unpack: `lo` covers elements 0-3 | 8-11, `hi` 4-7 | 12-15.
    const __m256i lo = _mm256_unpacklo_epi16(wide0, wide1);
    const __m256i hi = _mm256_unpackhi_epi16(wide0, wide1);
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr), _mm256_extracti128_si256(lo, 0));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + tile_step),
            _mm256_extracti128_si256(hi, 0));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 2 * tile_step),
            _mm256_extracti128_si256(lo, 1));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 3 * tile_step),
            _mm256_extracti128_si256(hi, 1));
}
/**
 * Transpose an 8 x 16 byte tile in units of 2-byte pairs:
 * out[16p + 2r + b] = row_r[2p + b] for pair p in [0, 8), row r in [0, 8),
 * b in {0, 1}. Writes 128 bytes; all loads happen before any store.
 */
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void transpose_8x16_k2(
        dt_int8* out, const dt_int8* in0, const dt_int8* in1, const dt_int8* in2,
        const dt_int8* in3, const dt_int8* in4, const dt_int8* in5, const dt_int8* in6,
        const dt_int8* in7) {
    const dt_int8* const src[8] = {in0, in1, in2, in3, in4, in5, in6, in7};
    __m128i rows[8];
    for (int r = 0; r < 8; ++r) {
        rows[r] = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src[r]));
    }
    // Stage 1: interleave 16-bit pairs of neighbouring rows (2j, 2j+1).
    // pair_lo[j] holds pairs 0-3, pair_hi[j] holds pairs 4-7.
    __m128i pair_lo[4];
    __m128i pair_hi[4];
    for (int j = 0; j < 4; ++j) {
        pair_lo[j] = _mm_unpacklo_epi16(rows[2 * j], rows[2 * j + 1]);
        pair_hi[j] = _mm_unpackhi_epi16(rows[2 * j], rows[2 * j + 1]);
    }
    // Stage 2: interleave 32-bit groups so each vector spans four rows.
    // quad[h][s]: rows 4h..4h+3, pairs 2s and 2s+1.
    __m128i quad[2][4];
    for (int h = 0; h < 2; ++h) {
        quad[h][0] = _mm_unpacklo_epi32(pair_lo[2 * h], pair_lo[2 * h + 1]);
        quad[h][1] = _mm_unpackhi_epi32(pair_lo[2 * h], pair_lo[2 * h + 1]);
        quad[h][2] = _mm_unpacklo_epi32(pair_hi[2 * h], pair_hi[2 * h + 1]);
        quad[h][3] = _mm_unpackhi_epi32(pair_hi[2 * h], pair_hi[2 * h + 1]);
    }
    // Stage 3: join the two row halves; pair p of all eight rows goes to
    // out + 16 * p.
    for (int s = 0; s < 4; ++s) {
        _mm_storeu_si128(
                reinterpret_cast<__m128i*>(out + 16 * (2 * s)),
                _mm_unpacklo_epi64(quad[0][s], quad[1][s]));
        _mm_storeu_si128(
                reinterpret_cast<__m128i*>(out + 16 * (2 * s + 1)),
                _mm_unpackhi_epi64(quad[0][s], quad[1][s]));
    }
}
/**
 * Widen two 8-byte int8 rows to int16 and interleave them element-wise,
 * storing two 8-element tiles `tile_step` int16 apart:
 * out[t * tile_step + 2j + r] = row_r[4t + j] for t in [0, 2), j in [0, 4).
 */
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void transpose_km_2x8_k2_tile4_int8_to_int16(
        const int8_t* inptr0, const int8_t* inptr1, int16_t* outptr, int tile_step) {
    const __m128i wide0 = _mm_cvtepi8_epi16_from_ptr(inptr0);
    const __m128i wide1 = _mm_cvtepi8_epi16_from_ptr(inptr1);
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr), _mm_unpacklo_epi16(wide0, wide1));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + tile_step),
            _mm_unpackhi_epi16(wide0, wide1));
}
/**
 * Widen four 16-byte int8 rows to int16 and transpose them in units of
 * element pairs: out[8p + 2r + b] = row_r[2p + b] for pair p in [0, 8),
 * row r in [0, 4), b in {0, 1}. Writes 64 int16 values.
 */
MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void transpose_4x16_k2_int8_to_int16(
        const int8_t* inptr0, const int8_t* inptr1, const int8_t* inptr2,
        const int8_t* inptr3, int16_t* outptr) {
    const __m256i w0 = _mm256_cvtepi8_epi16_from_ptr(inptr0);
    const __m256i w1 = _mm256_cvtepi8_epi16_from_ptr(inptr1);
    const __m256i w2 = _mm256_cvtepi8_epi16_from_ptr(inptr2);
    const __m256i w3 = _mm256_cvtepi8_epi16_from_ptr(inptr3);
    const __m256i w01_lo = _mm256_unpacklo_epi32(w0, w1);
    const __m256i w01_hi = _mm256_unpackhi_epi32(w0, w1);
    const __m256i w23_lo = _mm256_unpacklo_epi32(w2, w3);
    const __m256i w23_hi = _mm256_unpackhi_epi32(w2, w3);
    // groups[g] holds pair g (low lane) and pair g + 4 (high lane) of all
    // four rows.
    const __m256i groups[4] = {
            _mm256_unpacklo_epi64(w01_lo, w23_lo),
            _mm256_unpackhi_epi64(w01_lo, w23_lo),
            _mm256_unpacklo_epi64(w01_hi, w23_hi),
            _mm256_unpackhi_epi64(w01_hi, w23_hi)};
    for (int g = 0; g < 4; ++g) {
        _mm_storeu_si128(
                reinterpret_cast<__m128i*>(outptr + 8 * g),
                _mm256_extracti128_si256(groups[g], 0));
    }
    for (int g = 0; g < 4; ++g) {
        _mm_storeu_si128(
                reinterpret_cast<__m128i*>(outptr + 32 + 8 * g),
                _mm256_extracti128_si256(groups[g], 1));
    }
}
/**
 * Widen four 8-byte int8 rows to int16 and transpose them in units of
 * element pairs: out[8p + 2r + b] = row_r[2p + b] for pair p in [0, 4),
 * row r in [0, 4), b in {0, 1}. Writes 32 int16 values.
 */
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void transpose_4x8_k2_int8_to_int16(
        const int8_t* inptr0, const int8_t* inptr1, const int8_t* inptr2,
        const int8_t* inptr3, int16_t* outptr) {
    const __m128i w0 = _mm_cvtepi8_epi16_from_ptr(inptr0);
    const __m128i w1 = _mm_cvtepi8_epi16_from_ptr(inptr1);
    const __m128i w2 = _mm_cvtepi8_epi16_from_ptr(inptr2);
    const __m128i w3 = _mm_cvtepi8_epi16_from_ptr(inptr3);
    const __m128i w01_lo = _mm_unpacklo_epi32(w0, w1);
    const __m128i w01_hi = _mm_unpackhi_epi32(w0, w1);
    const __m128i w23_lo = _mm_unpacklo_epi32(w2, w3);
    const __m128i w23_hi = _mm_unpackhi_epi32(w2, w3);
    // groups[p] holds pair p of all four rows.
    const __m128i groups[4] = {
            _mm_unpacklo_epi64(w01_lo, w23_lo), _mm_unpackhi_epi64(w01_lo, w23_lo),
            _mm_unpacklo_epi64(w01_hi, w23_hi), _mm_unpackhi_epi64(w01_hi, w23_hi)};
    for (int p = 0; p < 4; ++p) {
        _mm_storeu_si128(reinterpret_cast<__m128i*>(outptr + 8 * p), groups[p]);
    }
}
/**
 * Return an 8 x int32 mask whose first `x` lanes are -1 (all bits set) and
 * whose remaining lanes are 0.
 *
 * @param x number of leading active lanes. Must be in [0, 8]: it indexes a
 *          9-entry table and is not range-checked.
 */
MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline __v8si _m256_continue_mask_v8si(const int& x) {
    // Read-only lookup table. `const` keeps it out of writable data and
    // prevents any stray write from corrupting every later mask.
    static const __v8si map[9] = {
            {00, 00, 00, 00, 00, 00, 00, 00},
            {-1, 00, 00, 00, 00, 00, 00, 00},
            {-1, -1, 00, 00, 00, 00, 00, 00},
            {-1, -1, -1, 00, 00, 00, 00, 00},
            {-1, -1, -1, -1, 00, 00, 00, 00},
            {-1, -1, -1, -1, -1, 00, 00, 00},
            {-1, -1, -1, -1, -1, -1, 00, 00},
            {-1, -1, -1, -1, -1, -1, -1, 00},
            {-1, -1, -1, -1, -1, -1, -1, -1}};
    return map[x];
}
/**
 * __m256i view of _m256_continue_mask_v8si: the first `x` int32 lanes are
 * all ones and the rest are zero. `x` must be in [0, 8].
 */
MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline __m256i _m256_continue_mask(const int& x) {
    const __v8si lanes = _m256_continue_mask_v8si(x);
    return (__m256i)lanes;
}
/**
 * Return a 16 x int8 mask whose first `x` bytes are -1 (0xFF) and whose
 * remaining bytes are 0.
 *
 * @param x number of leading active bytes. Must be in [0, 16]: it indexes a
 *          17-entry table and is not range-checked.
 */
MEGDNN_ATTRIBUTE_TARGET("sse2")
static inline __m128i _mm_continue_mask(const int& x) {
    // Read-only lookup table. `const` keeps it out of writable data and
    // prevents any stray write from corrupting every later mask.
    static const __v16qi map[17] = {
            {00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
            {-1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
            {-1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
            {-1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
            {-1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
            {-1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
            {-1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
            {-1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00},
            {-1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00},
            {-1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00},
            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00},
            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00},
            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00},
            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00},
            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00},
            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00},
            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
    };
    return (__m128i)map[x];
}
/**
 * Widen four int8 rows of length `k` to int16 and interleave them in pairs
 * of elements. For each complete pair of columns, writes two values from
 * inptr0, then from inptr1, inptr2 and inptr3. If `k` is odd, a final step
 * writes the last element of each row followed by a 0.
 *
 * The caller's pointers are taken by value and are not modified.
 */
MEGDNN_ATTRIBUTE_TARGET("sse2")
static inline void transpose_4xk_int8_to_int16_pad(
        const int8_t* inptr0, const int8_t* inptr1, const int8_t* inptr2,
        const int8_t* inptr3, int16_t* outptr, int k) {
    constexpr int k_step = 2;
    const int k_end = k / k_step * k_step;
    const int k_remain = k - k_end;
    const int8_t* const rows[4] = {inptr0, inptr1, inptr2, inptr3};
    for (int i = 0; i < k_end; i += k_step) {
        for (const int8_t* row : rows) {
            *outptr++ = static_cast<int16_t>(row[i]);
            *outptr++ = static_cast<int16_t>(row[i + 1]);
        }
    }
    // Odd k: pad the last pair of every row with a zero. (The original kept
    // stepping a loop counter here that was never read again.)
    if (k_remain > 0) {
        for (const int8_t* row : rows) {
            *outptr++ = static_cast<int16_t>(row[k_end]);
            *outptr++ = 0;
        }
    }
}
/**
 * Widen two int8 rows of length `k` to int16 and interleave them in pairs
 * of elements. For each complete pair of columns, writes two values from
 * inptr0, then two from inptr1. If `k` is odd, a final step writes the last
 * element of each row followed by a 0.
 *
 * The caller's pointers are taken by value and are not modified.
 */
MEGDNN_ATTRIBUTE_TARGET("sse2")
static inline void transpose_2xk_k2_pad(
        const int8_t* inptr0, const int8_t* inptr1, int16_t* outptr, int k) {
    constexpr int k_step = 2;
    const int k_end = k / k_step * k_step;
    const int k_remain = k - k_end;
    for (int i = 0; i < k_end; i += k_step) {
        *outptr++ = static_cast<int16_t>(inptr0[i]);
        *outptr++ = static_cast<int16_t>(inptr0[i + 1]);
        *outptr++ = static_cast<int16_t>(inptr1[i]);
        *outptr++ = static_cast<int16_t>(inptr1[i + 1]);
    }
    // Odd k: pad the last pair of each row with a zero. No counters are
    // advanced afterwards since nothing reads them.
    if (k_remain > 0) {
        outptr[0] = static_cast<int16_t>(inptr0[k_end]);
        outptr[1] = 0;
        outptr[2] = static_cast<int16_t>(inptr1[k_end]);
        outptr[3] = 0;
    }
}
/**
 * Transpose a 4 x 32 byte tile: out[4c + r] = row_r[c] for c in [0, 32),
 * r in [0, 4). Writes 128 bytes.
 *
 * The four input pointers are advanced by 32; `outptr` is taken by value.
 */
MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void transpose_4x32_1_b(
        const int8_t*& inptr0, const int8_t*& inptr1, const int8_t*& inptr2,
        const int8_t*& inptr3, int8_t* outptr) {
    const __m256i a = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(inptr0));
    const __m256i b = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(inptr1));
    const __m256i c = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(inptr2));
    const __m256i d = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(inptr3));
    const __m256i ab_lo = _mm256_unpacklo_epi8(a, b);
    const __m256i ab_hi = _mm256_unpackhi_epi8(a, b);
    const __m256i cd_lo = _mm256_unpacklo_epi8(c, d);
    const __m256i cd_hi = _mm256_unpackhi_epi8(c, d);
    // cols[i] holds columns 4i..4i+3 (low lane) and 16+4i..19+4i (high lane).
    const __m256i cols[4] = {
            _mm256_unpacklo_epi16(ab_lo, cd_lo), _mm256_unpackhi_epi16(ab_lo, cd_lo),
            _mm256_unpacklo_epi16(ab_hi, cd_hi), _mm256_unpackhi_epi16(ab_hi, cd_hi)};
    for (int i = 0; i < 4; ++i) {
        _mm_storeu_si128(
                reinterpret_cast<__m128i*>(outptr + 16 * i),
                _mm256_extracti128_si256(cols[i], 0));
    }
    for (int i = 0; i < 4; ++i) {
        _mm_storeu_si128(
                reinterpret_cast<__m128i*>(outptr + 64 + 16 * i),
                _mm256_extracti128_si256(cols[i], 1));
    }
    inptr0 += 32;
    inptr1 += 32;
    inptr2 += 32;
    inptr3 += 32;
}
/**
 * Transpose a 4 x 16 byte tile: out[4c + r] = row_r[c] for c in [0, 16),
 * r in [0, 4). Writes 64 bytes.
 *
 * The four input pointers are advanced by 16. NOTE: although `outptr` is
 * taken by reference it is NOT advanced; callers must step it themselves.
 */
template <typename T>
MEGDNN_ATTRIBUTE_TARGET("sse3")
static inline void transpose_4x16_1_b(
        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
        T*& outptr) {
    // The diagnostic used to name a non-existent "interleave_4x16_1_h".
    static_assert(
            std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
            "transpose_4x16_1_b only support uint8_t and int8_t");
    const __m128i R0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(inptr0));
    const __m128i R1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(inptr1));
    const __m128i R2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(inptr2));
    const __m128i R3 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(inptr3));
    // Byte-interleave row pairs, then 16-bit-interleave the pairs so that each
    // 4-byte group holds one column of all four rows.
    const __m128i R01L = _mm_unpacklo_epi8(R0, R1);
    const __m128i R01H = _mm_unpackhi_epi8(R0, R1);
    const __m128i R23L = _mm_unpacklo_epi8(R2, R3);
    const __m128i R23H = _mm_unpackhi_epi8(R2, R3);
    const __m128i Out0_3 = _mm_unpacklo_epi16(R01L, R23L);
    const __m128i Out4_7 = _mm_unpackhi_epi16(R01L, R23L);
    const __m128i Out8_11 = _mm_unpacklo_epi16(R01H, R23H);
    const __m128i Out12_15 = _mm_unpackhi_epi16(R01H, R23H);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(outptr), Out0_3);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(outptr + 16), Out4_7);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(outptr + 32), Out8_11);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(outptr + 48), Out12_15);
    inptr0 += 16;
    inptr1 += 16;
    inptr2 += 16;
    inptr3 += 16;
}
/**
 * Transpose a 4 x 12 byte tile and convert int8 to uint8 by adding 128
 * (mod 256): out[4c + r] = row_r[c] + 128 for c in [0, 12), r in [0, 4).
 * Writes 48 bytes.
 *
 * The input pointers are advanced by 12. `outptr`, although passed by
 * reference, is not advanced.
 *
 * NOTE(review): each row is read with a full 16-byte load even though only
 * 12 bytes are consumed, so 4 bytes past the tile are touched. Callers
 * presumably guarantee that memory is readable — confirm.
 */
MEGDNN_ATTRIBUTE_TARGET("sse3")
static inline void transpose_4x12_1_b_add_128(
        const int8_t*& inptr0, const int8_t*& inptr1, const int8_t*& inptr2,
        const int8_t*& inptr3, uint8_t*& outptr) {
    // Adding -128 modulo 256 is the same as adding 128: maps int8 to uint8.
    const __m128i offset = _mm_set1_epi8(-128);
    const int8_t* const src[4] = {inptr0, inptr1, inptr2, inptr3};
    __m128i v[4];
    for (int r = 0; r < 4; ++r) {
        v[r] = _mm_add_epi8(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(src[r])), offset);
    }
    // Byte-interleave row pairs, then 16-bit-interleave the pairs so each
    // 4-byte group holds one column of all four rows. Columns 12-15 are
    // never produced.
    const __m128i v01_lo = _mm_unpacklo_epi8(v[0], v[1]);
    const __m128i v01_hi = _mm_unpackhi_epi8(v[0], v[1]);
    const __m128i v23_lo = _mm_unpacklo_epi8(v[2], v[3]);
    const __m128i v23_hi = _mm_unpackhi_epi8(v[2], v[3]);
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr), _mm_unpacklo_epi16(v01_lo, v23_lo));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 16),
            _mm_unpackhi_epi16(v01_lo, v23_lo));
    _mm_storeu_si128(
            reinterpret_cast<__m128i*>(outptr + 32),
            _mm_unpacklo_epi16(v01_hi, v23_hi));
    inptr0 += 12;
    inptr1 += 12;
    inptr2 += 12;
    inptr3 += 12;
}
} }