#pragma once
#include "megdnn/arch.h"
#include <immintrin.h>
#ifdef WIN32
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#endif
#if !defined(__clang__)
#pragma GCC target("avx")
#endif
namespace megdnn {
namespace x86 {
MEGDNN_ATTRIBUTE_TARGET("avx")
static inline __m256 _mm256_loadu2_m128_emulate(
const float* hiaddr, const float* loaddr) {
return _mm256_insertf128_ps(
_mm256_castps128_ps256(_mm_loadu_ps(loaddr)), _mm_loadu_ps(hiaddr), 1);
}
MEGDNN_ATTRIBUTE_TARGET("avx")
static inline void _mm256_storeu2_m128_emulate(
float* hiaddr, float* loaddr, __m256 reg) {
auto xmm0 = _mm256_extractf128_ps(reg, 0);
auto xmm1 = _mm256_extractf128_ps(reg, 1);
_mm_storeu_ps(loaddr, xmm0);
_mm_storeu_ps(hiaddr, xmm1);
}
template <typename ctype, size_t len>
struct Vector;
template <>
struct Vector<float, 8> {
__m256 value;
Vector() {}
Vector(const float v) MEGDNN_ATTRIBUTE_TARGET("avx") { value = _mm256_set1_ps(v); }
Vector(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") { value = lr.value; }
Vector(const Vector&& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
value = std::move(lr.value);
}
Vector(const __m256& v) MEGDNN_ATTRIBUTE_TARGET("avx") { value = v; }
static Vector load(const float* addr) MEGDNN_ATTRIBUTE_TARGET("avx") {
Vector v;
v.value = _mm256_loadu_ps(addr);
return v;
}
static void save(float* addr, const Vector& v) MEGDNN_ATTRIBUTE_TARGET("avx") {
_mm256_storeu_ps(addr, v.value);
}
void save(float* addr) MEGDNN_ATTRIBUTE_TARGET("avx") { save(addr, *this); }
Vector operator+(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
Vector dst;
dst.value = _mm256_add_ps(value, lr.value);
return dst;
}
Vector& operator+=(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
value = _mm256_add_ps(value, lr.value);
return *this;
}
Vector operator-(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
Vector dst;
dst.value = _mm256_sub_ps(value, lr.value);
return dst;
}
Vector& operator-=(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
value = _mm256_sub_ps(value, lr.value);
return *this;
}
Vector operator*(float lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
Vector dst;
dst.value = _mm256_mul_ps(value, _mm256_set1_ps(lr));
return dst;
}
Vector operator*(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
Vector dst;
dst.value = _mm256_mul_ps(value, lr.value);
return dst;
}
Vector& operator*=(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
value = _mm256_mul_ps(value, lr.value);
return *this;
}
Vector& operator=(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
value = lr.value;
return *this;
}
Vector& operator=(const Vector&& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
value = std::move(lr.value);
return *this;
}
Vector operator-() MEGDNN_ATTRIBUTE_TARGET("avx") {
Vector dst;
dst.value = -value;
return dst;
}
};
#if !defined(__clang__)
#pragma GCC reset_options
#endif
} }