#include "src/x86/conv_bias/int8/avx2_chanwise_stride1.h"
#include "src/x86/conv_bias/int8/avx2_chanwise_kern.h"
#include "src/x86/elemwise_op.h"
namespace megdnn {
namespace x86 {
namespace avx2_chanwise_stride1 {
/*!
 * \brief Run one (group, batch) task of the AVX2 channel-wise stride-1 int8
 *        convolution for a compile-time filter size.
 *
 * Workspace slots, as indexed by this function:
 *   - bundle.get(0): per-thread source plane of IH2 * IW2 int8 values, read
 *     instead of the caller's source when need_src_copy() is true (presumably
 *     filled by the padding-copy step that runs before this function --
 *     confirm at the call site);
 *   - bundle.get(1): per-thread staging output of OH2 * OW2 destination
 *     elements, used when need_dst_copy() is true;
 *   - bundle.get(2): per-thread buffer of OH2 * OW2 int32 values, used only
 *     when the destination dtype is QuantizedS8.
 *
 * \tparam filter       square filter size handed to DISPATCH_FILTER
 * \tparam bias_mode    bias mode forwarded to the kernel
 * \tparam is_quantized not referenced in the body; the run-time destination
 *                      dtype selects the post-process path instead
 * \tparam Op           post-process functor, constructible from two floats
 * \param bundle     workspace bundle whose base pointer has already been set
 * \param kern_param convolution parameters and tensor accessors
 * \param ncb_index  thread id plus the (group, batch) index of this task
 */
template <size_t filter, BiasMode bias_mode, bool is_quantized, typename Op>
void conv_kimpl(
        const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
        const NCBKernIndex& ncb_index) {
    const size_t OH = kern_param.osz[0];
    const size_t OW = kern_param.osz[1];
    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);

    const bool need_src_copy_var = need_src_copy(kern_param);
    const bool need_dst_copy_var = need_dst_copy(kern_param);
    const bool need_post_process =
            kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;

    // Placeholder scales; they are replaced by the real quantization scales
    // whenever the post-process path is actually taken.
    Op op = Op(1.0f, 4.0f);
    if (need_post_process) {
        float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
        float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
        op = Op(scale_bias, scale_dst);
    }

    const size_t thread_id = ncb_index.thread_id;
    const size_t group_id = ncb_index.ndrange_id[0];
    const size_t batch_id = ncb_index.ndrange_id[1];
    const size_t dst_elem_size = kern_param.dst_type.size();
    const size_t out_plane_elems = OH2 * OW2;

    const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id);
    const int8_t* fptr = kern_param.filter<dt_int8>(group_id);
    void* dst = kern_param.dst<void>(batch_id, group_id);
    const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id);

    if (need_src_copy_var) {
        // Each thread owns one IH2 * IW2 slice of workspace slot 0.
        sptr = static_cast<int8_t*>(bundle.get(0)) + thread_id * IH2 * IW2;
    }

    // The kernel writes OH2 x OW2 outputs. When that differs from the real
    // destination layout, write into a per-thread staging slice of workspace
    // slot 1 and copy the valid rows out afterwards.
    void* dptr = dst;
    if (need_dst_copy_var) {
        dptr = static_cast<char*>(bundle.get(1)) +
               thread_id * out_plane_elems * dst_elem_size;
    }
    int32_t* tptr = nullptr;

#define KERN_NEED_POST_PROCESS(filter)                                              \
    avx2_chanwise_direct_stride1_##filter##x##filter##_int8<bias_mode, true, Op>(   \
            sptr, fptr, bptr, tptr, static_cast<int8_t*>(dptr), IH2, IW2, OH2, OW2, \
            op)
#define KERN_NO_POST_PROCESS(filter)                                                \
    avx2_chanwise_direct_stride1_##filter##x##filter##_int8<bias_mode, false, Op>(  \
            sptr, fptr, bptr, static_cast<int32_t*>(dptr), nullptr, IH2, IW2, OH2,  \
            OW2, op)

    if (need_post_process) {
        // tptr is an int32_t*, so the offset is counted in int32 elements. The
        // destination is QuantizedS8 on this path (one byte per element), so
        // the former `* dst_type.size()` factor was a no-op and is dropped.
        tptr = static_cast<int32_t*>(bundle.get(2)) + thread_id * out_plane_elems;
        DISPATCH_FILTER(filter, KERN_NEED_POST_PROCESS)
    } else {
        DISPATCH_FILTER(filter, KERN_NO_POST_PROCESS)
    }
#undef KERN_NEED_POST_PROCESS
#undef KERN_NO_POST_PROCESS

    if (need_dst_copy_var) {
        // Copy the first OW elements of each of the first OH staged rows
        // (row pitch OW2) into the densely packed destination (row pitch OW).
        const size_t dst_row_bytes = OW * dst_elem_size;
        const size_t staged_row_bytes = OW2 * dst_elem_size;
        char* out_rows = static_cast<char*>(dst);
        const char* staged_rows = static_cast<const char*>(dptr);
        for (size_t oh = 0; oh < OH; ++oh) {
            std::memcpy(
                    out_rows + oh * dst_row_bytes,
                    staged_rows + oh * staged_row_bytes, dst_row_bytes);
        }
    }
}
/*!
 * \brief Build the kernel list for the AVX2 channel-wise stride-1 int8
 *        convolution.
 *
 * Selects a conv_kimpl instantiation from four run-time properties:
 *   - filter_meta.spatial[0]: 2, 3, 5 or 7;
 *   - dst_type: QuantizedS8 (is_quantized = true), QuantizedS32 or Int32
 *     (is_quantized = false);
 *   - bias_mode: NO_BIAS or BROADCAST_CHANNEL_BIAS;
 *   - nonlineMode: IDENTITY, RELU or H_SWISH.
 * Any other value hits megdnn_assert(0).
 *
 * \param kern_param size/dtype description of the convolution
 * \param bundle     workspace layout; copied into the returned kernel, which
 *                   binds it to kern_param.workspace_ptr on each invocation
 * \return a single kernel with ndrange {group, n, 1}; it runs
 *         copy_padding_kern and then the selected conv_kimpl
 */
SmallVector<NCBKern> get_kimpls(
        const NCBKernSizeParam& kern_param, const WorkspaceBundle& bundle) {
    const size_t group = kern_param.filter_meta.group;
    const size_t n = kern_param.n;
    conv_fun do_conv_fun = nullptr;

// The macros below form a nested switch cascade (filter size -> dst dtype ->
// bias mode -> nonlinearity) that turns run-time values into the template
// arguments of conv_kimpl. The Op type always converts qint32 -> qint8.
#define DO_CONV_KERN_FUN(filter, bias_mode, is_quantized, op) \
    do_conv_fun = conv_kimpl<filter, bias_mode, is_quantized, op>;

#define GET_OP_PARAM(i, bias_mode, is_quantized)                                \
    switch (kern_param.nonlineMode) {                                           \
        case param::ConvBias::NonlineMode::IDENTITY:                            \
            DO_CONV_KERN_FUN(                                                   \
                    i, bias_mode, is_quantized,                                 \
                    TypeCvtOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32 MEGDNN_COMMA \
                                      dt_qint8>)                                \
            break;                                                              \
        case param::ConvBias::NonlineMode::RELU:                                \
            DO_CONV_KERN_FUN(                                                   \
                    i, bias_mode, is_quantized,                                 \
                    ReluOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32 MEGDNN_COMMA   \
                                   dt_qint8>)                                   \
            break;                                                              \
        case param::ConvBias::NonlineMode::H_SWISH:                             \
            DO_CONV_KERN_FUN(                                                   \
                    i, bias_mode, is_quantized,                                 \
                    HSwishOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32 MEGDNN_COMMA \
                                     dt_qint8>)                                 \
            break;                                                              \
        default:                                                                \
            megdnn_assert(0);                                                   \
            break;                                                              \
    }

#define GET_BIAS_MODE_PARAM(i, is_quantized)                                    \
    switch (kern_param.bias_mode) {                                             \
        case BiasMode::NO_BIAS:                                                 \
            GET_OP_PARAM(i, BiasMode::NO_BIAS, is_quantized)                    \
            break;                                                              \
        case BiasMode::BROADCAST_CHANNEL_BIAS:                                  \
            GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS, is_quantized)     \
            break;                                                              \
        default:                                                                \
            megdnn_assert(0);                                                   \
            break;                                                              \
    }

#define GET_QUANTIZED(i)                                                        \
    switch (kern_param.dst_type.enumv()) {                                      \
        case DTypeEnum::QuantizedS8:                                            \
            GET_BIAS_MODE_PARAM(i, true)                                        \
            break;                                                              \
        case DTypeEnum::QuantizedS32:                                           \
        case DTypeEnum::Int32:                                                  \
            GET_BIAS_MODE_PARAM(i, false)                                       \
            break;                                                              \
        default:                                                                \
            megdnn_assert(0);                                                   \
            break;                                                              \
    }

#define DISPATCH_CONV_KERN()                                                    \
    switch (kern_param.filter_meta.spatial[0]) {                                \
        case 2:                                                                 \
            GET_QUANTIZED(2)                                                    \
            break;                                                              \
        case 3:                                                                 \
            GET_QUANTIZED(3)                                                    \
            break;                                                              \
        case 5:                                                                 \
            GET_QUANTIZED(5)                                                    \
            break;                                                              \
        case 7:                                                                 \
            GET_QUANTIZED(7)                                                    \
            break;                                                              \
        default:                                                                \
            megdnn_assert(0);                                                   \
            break;                                                              \
    }

    DISPATCH_CONV_KERN();

// Keep the helper macros local to this function so they cannot collide with
// identically named macros later in the translation unit.
#undef DISPATCH_CONV_KERN
#undef GET_QUANTIZED
#undef GET_BIAS_MODE_PARAM
#undef GET_OP_PARAM
#undef DO_CONV_KERN_FUN

    // The closure owns its own copy of the bundle (hence `mutable`) so that it
    // can attach the workspace pointer supplied with each invocation.
    auto exec_one_group = [bundle = bundle, do_conv_fun](
                                  const NCBKernParam& kern_param,
                                  const NCBKernIndex& ncb_index) mutable {
        bundle.set(kern_param.workspace_ptr);
        copy_padding_kern(bundle, kern_param, ncb_index);
        do_conv_fun(bundle, kern_param, ncb_index);
    };

    SmallVector<NCBKern> ncb_kerns;
    ncb_kerns.push_back({exec_one_group, {group, n, 1_z}});
    return ncb_kerns;
}
} } }