#pragma once
#include "megdnn/arch.h"
#include "megdnn/basic_types.h"
#include <cstddef>
namespace megdnn {
namespace x86 {
namespace detail {

// Declarations of AVX element-wise float kernels.
//
// Every operator below is declared as a pair of overloads:
//   * an entry point taking TensorND operands, and
//   * a raw-float-pointer overload tagged with MEGDNN_ATTRIBUTE_TARGET("avx").
// Only the raw-pointer overloads carry that attribute in this header.
// NOTE(review): MEGDNN_ATTRIBUTE_TARGET presumably expands to a per-function
// target("avx") attribute, letting the kernel use AVX without compiling the
// whole TU with -mavx -- confirm against megdnn/arch.h.
//
// The AVX_ELEMENT_OPTR_* helper macros exist only to stamp out these
// declarations. They are #undef'd right after their last use so they do not
// leak into (or collide with macros in) translation units including this
// header.

// Unary operators: input `src_ptr`, output `dst_ptr`, `tsize` presumably the
// element count (TODO confirm in the implementation).
#define AVX_ELEMENT_OPTR_UNARY(optr_type)                                            \
    void avx_element_##optr_type(                                                    \
            const TensorND& src_tensor, const TensorND& dst_tensor);                 \
    void avx_element_##optr_type(size_t tsize, const float* src_ptr, float* dst_ptr) \
            MEGDNN_ATTRIBUTE_TARGET("avx");
AVX_ELEMENT_OPTR_UNARY(sigmoid)
AVX_ELEMENT_OPTR_UNARY(exp)
AVX_ELEMENT_OPTR_UNARY(tanh)
AVX_ELEMENT_OPTR_UNARY(fast_tanh)
AVX_ELEMENT_OPTR_UNARY(relu)
AVX_ELEMENT_OPTR_UNARY(set)
#undef AVX_ELEMENT_OPTR_UNARY

// Binary operators on two equally-sized inputs (`src0`, `src1`) writing to
// `dst`; `tsize` is presumably the element count -- TODO confirm.
#define AVX_ELEMENT_OPTR_BINARY(optr_type)                                 \
    void avx_element_##optr_type(                                          \
            const TensorND& src0_tensor, const TensorND& src1_tensor,      \
            const TensorND& dst_tensor);                                   \
    void avx_element_##optr_type(                                          \
            size_t tsize, float* src0_ptr, float* src1_ptr, float* dst_ptr) \
            MEGDNN_ATTRIBUTE_TARGET("avx");
AVX_ELEMENT_OPTR_BINARY(add)
AVX_ELEMENT_OPTR_BINARY(bias_sigmoid)
AVX_ELEMENT_OPTR_BINARY(bias_relu)
AVX_ELEMENT_OPTR_BINARY(bias_tanh)
#undef AVX_ELEMENT_OPTR_BINARY

// Raw-pointer only: combines `src_ptr` with the single scalar `bias` into
// `dst_ptr` (presumably dst[i] = src[i] + bias for i < tsize -- TODO confirm).
void avx_element_add_scalar(
        const size_t tsize, float* src_ptr, float* dst_ptr, const float bias)
        MEGDNN_ATTRIBUTE_TARGET("avx");

// Binary operators in "1c11" form, parameterized by batch/channel sizes and a
// channel stride. NOTE(review): presumably `src1` is a per-channel operand of
// shape (1, C, 1, 1) broadcast against `src0` -- confirm in the implementation.
#define AVX_ELEMENT_OPTR_BINARY_1C11(optr_type)                              \
    void avx_element_##optr_type(                                            \
            size_t batch_size, size_t channel_size, size_t channel_stride,   \
            const TensorND& src0_tensor, const TensorND& src1_tensor,        \
            const TensorND& dst_tensor);                                     \
    void avx_element_##optr_type(                                            \
            size_t batch_size, size_t channel_size, size_t channel_stride,   \
            float* src0_ptr, float* src1_ptr, float* dst_ptr)                \
            MEGDNN_ATTRIBUTE_TARGET("avx");
AVX_ELEMENT_OPTR_BINARY_1C11(add_1c11)
AVX_ELEMENT_OPTR_BINARY_1C11(bias_sigmoid_1c11)
AVX_ELEMENT_OPTR_BINARY_1C11(bias_relu_1c11)
#undef AVX_ELEMENT_OPTR_BINARY_1C11

// Ternary operators on three inputs (`src0`, `src1`, `src2`) writing to `dst`.
// NOTE(review): "fma3" presumably means dst = src0 * src1 + src2 -- confirm.
#define AVX_ELEMENT_OPTR_TERNARY(optr_type)                                \
    void avx_element_##optr_type(                                          \
            const TensorND& src0_tensor, const TensorND& src1_tensor,      \
            const TensorND& src2_tensor, const TensorND& dst_tensor);      \
    void avx_element_##optr_type(                                          \
            size_t tsize, float* src0_ptr, float* src1_ptr, float* src2_ptr, \
            float* dst_ptr) MEGDNN_ATTRIBUTE_TARGET("avx");
AVX_ELEMENT_OPTR_TERNARY(fma3)
AVX_ELEMENT_OPTR_TERNARY(fma3_scalar)
#undef AVX_ELEMENT_OPTR_TERNARY

// Ternary operators in "1c11" form (see the binary 1c11 note above for the
// assumed broadcast layout -- TODO confirm which operands are per-channel).
#define AVX_ELEMENT_OPTR_TERNARY_1C11(optr_type)                                 \
    void avx_element_##optr_type(                                                \
            size_t batch_size, size_t channel_size, size_t channel_stride,       \
            const TensorND& src0_tensor, const TensorND& src1_tensor,            \
            const TensorND& src2_tensor, const TensorND& dst_tensor);            \
    void avx_element_##optr_type(                                                \
            size_t batch_size, size_t channel_size, size_t channel_stride,       \
            float* src0_ptr, float* src1_ptr, float* src2_ptr, float* dst_ptr)   \
            MEGDNN_ATTRIBUTE_TARGET("avx");
AVX_ELEMENT_OPTR_TERNARY_1C11(fma3_1c11)
#undef AVX_ELEMENT_OPTR_TERNARY_1C11

// Raw-pointer only: applies `bias_ptr` to `output_ptr` in place, parameterized
// by `channel_size` and `channel_stride`. NOTE(review): presumably one bias
// value per channel over `channel_stride` elements each -- confirm.
void avx_element_add_in_a_channel(
        float* output_ptr, float* bias_ptr, size_t channel_size, size_t channel_stride)
        MEGDNN_ATTRIBUTE_TARGET("avx");

}  // namespace detail
}  // namespace x86
}  // namespace megdnn