#pragma once
#include "megdnn/arch.h"
#include "megdnn/basic_types.h"
#include <cstddef>
namespace megdnn {
namespace x86 {
namespace detail {

// Declarations of the x86 FMA element-wise helpers `fma_element_<op>`.
// This header only declares them; the definitions are provided elsewhere.
//
// Each helper macro below declares two overloads for one op:
//   * an overload taking TensorND arguments, and
//   * an overload taking raw float pointers, annotated with
//     MEGDNN_ATTRIBUTE_TARGET("fma"). NOTE(review): presumably this lets that
//     overload be compiled with FMA code generation enabled independently of
//     the global target flags -- confirm against megdnn/arch.h.
//
// The helper macros are #undef'd right after use so they do not leak into
// every translation unit that includes this header.

// Unary ops: one source and one destination.
// NOTE(review): `tsize` is presumably the number of float elements to
// process -- confirm against the definitions.
#define FMA_ELEMENT_OPTR_UNARY(optr_type)                                      \
    void fma_element_##optr_type(                                              \
            const TensorND& src_tensor, const TensorND& dst_tensor);           \
    void fma_element_##optr_type(                                              \
            size_t tsize, const float* src_ptr, float* dst_ptr)                \
            MEGDNN_ATTRIBUTE_TARGET("fma");

FMA_ELEMENT_OPTR_UNARY(sigmoid)
FMA_ELEMENT_OPTR_UNARY(exp)
FMA_ELEMENT_OPTR_UNARY(tanh)
FMA_ELEMENT_OPTR_UNARY(fast_tanh)
FMA_ELEMENT_OPTR_UNARY(relu)
FMA_ELEMENT_OPTR_UNARY(set)

#undef FMA_ELEMENT_OPTR_UNARY

// Ternary ops: three sources and one destination.
// Note that the raw-pointer sources are declared as non-const `float*`; this
// must stay in sync with the out-of-line definitions, so it is kept as-is.
#define FMA_ELEMENT_OPTR_TERNARY(optr_type)                                    \
    void fma_element_##optr_type(                                              \
            const TensorND& src0_tensor, const TensorND& src1_tensor,          \
            const TensorND& src2_tensor, const TensorND& dst_tensor);          \
    void fma_element_##optr_type(                                              \
            size_t tsize, float* src0_ptr, float* src1_ptr, float* src2_ptr,   \
            float* dst_ptr) MEGDNN_ATTRIBUTE_TARGET("fma");

FMA_ELEMENT_OPTR_TERNARY(fma3)
FMA_ELEMENT_OPTR_TERNARY(fma3_scalar)

#undef FMA_ELEMENT_OPTR_TERNARY

// Ternary ops parameterized by an explicit batch / channel decomposition.
// NOTE(review): the "1c11" suffix presumably denotes operands broadcast with a
// (1, C, 1, 1) layout, iterated as batch_size x channel_size blocks of
// channel_stride elements -- confirm against the definitions.
#define FMA_ELEMENT_OPTR_TERNARY_1C11(optr_type)                               \
    void fma_element_##optr_type(                                              \
            size_t batch_size, size_t channel_size, size_t channel_stride,     \
            const TensorND& src0_tensor, const TensorND& src1_tensor,          \
            const TensorND& src2_tensor, const TensorND& dst_tensor);          \
    void fma_element_##optr_type(                                              \
            size_t batch_size, size_t channel_size, size_t channel_stride,     \
            float* src0_ptr, float* src1_ptr, float* src2_ptr, float* dst_ptr) \
            MEGDNN_ATTRIBUTE_TARGET("fma");

FMA_ELEMENT_OPTR_TERNARY_1C11(fma3_1c11)

#undef FMA_ELEMENT_OPTR_TERNARY_1C11

}  // namespace detail
}  // namespace x86
}  // namespace megdnn