#include "src/arm_common/conv_bias/int8/stride1.h"
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
using namespace megdnn;
using namespace arm_common;
using namespace direct_int8_stride1;
namespace {
//! Returns true when the output width (param.osz[1]) is not a multiple of 8.
bool need_dst_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) {
    const auto out_width = param.osz[1];
    return (out_width % 8) != 0;
}
//! Returns true when the filter has non-zero padding in either spatial
//! dimension; otherwise gives the same answer as need_dst_copy().
bool need_src_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) {
    const bool has_padding =
            param.filter_meta.padding[0] != 0 || param.filter_meta.padding[1] != 0;
    return has_padding || need_dst_copy(param);
}
//! Computes the "rectified" input/output extents derived from \p param.
//!
//! \param[in]  param  reads osz[0], osz[1], filter_meta.stride[1] and
//!                    filter_meta.spatial[0..1]
//! \param[out] IH2    SW * OH + FH - SW
//! \param[out] IW2    SW * OW2 + FW - SW
//! \param[out] OH2    the output height, unchanged
//! \param[out] OW2    the output width rounded up to the next multiple of 8
void get_rectified_size(
        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t& IH2,
        size_t& IW2, size_t& OH2, size_t& OW2) {
    const auto& meta = param.filter_meta;
    const auto stride_w = meta.stride[1];
    const auto filter_h = meta.spatial[0];
    const auto filter_w = meta.spatial[1];
    const auto out_h = param.osz[0];
    const auto out_w = param.osz[1];

    OH2 = out_h;
    // Clearing the three low bits after adding 7 rounds up to a multiple of 8.
    OW2 = (out_w + 7) & ~7;
    // NOTE(review): the height extent is derived from stride[1] as well;
    // stride[0] is never read here.
    IH2 = stride_w * out_h + filter_h - stride_w;
    IW2 = stride_w * OW2 + filter_w - stride_w;
}
}  // namespace

//! Tells whether the direct int8 stride-1 kernels accept \p param.
//!
//! The result is true only when every condition below holds:
//!  - src/filter are QuantizedS8 with a QuantizedS8 or QuantizedS32 dst, or
//!    src/filter are Int8 with an Int32 dst;
//!  - the format is NCHW, the filter is not flipped, there are 2 spatial
//!    dimensions, and dilation and stride are 1 in both of them;
//!  - the filter is square with side 2, 3, 5 or 7;
//!  - if a bias dtype is set, it is QuantizedS32 with a QuantizedS8 src, or
//!    it has the same dtype enum as dst;
//!  - the channel-count heuristic below is met and bias_mode is not BIAS.
bool direct_int8_stride1::can_conv_direct_stride1_int8(const NCBKernSizeParam& param) {
    const auto& meta = param.filter_meta;
    const auto filter_size = meta.spatial[0];
    const auto out_channels = meta.ocpg;
    const auto in_channels = meta.icpg;

    const bool quantized_types =
            param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
            param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
            (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
             param.dst_type.enumv() == DTypeEnum::QuantizedS32);
    const bool plain_int_types = param.src_type.enumv() == DTypeEnum::Int8 &&
                                 param.filter_type.enumv() == DTypeEnum::Int8 &&
                                 param.dst_type.enumv() == DTypeEnum::Int32;

    const bool layout_ok = meta.format == param::Convolution::Format::NCHW &&
                           !meta.should_flip && meta.spatial_ndim == 2;
    const bool unit_dilation = meta.dilation[0] == 1 && meta.dilation[1] == 1;
    const bool unit_stride = meta.stride[0] == 1 && meta.stride[1] == 1;
    const bool filter_ok =
            filter_size == meta.spatial[1] &&
            (filter_size == 2 || filter_size == 3 || filter_size == 5 ||
             filter_size == 7);

    bool available = (quantized_types || plain_int_types) && layout_ok &&
                     unit_dilation && unit_stride && filter_ok;

    // The bias dtype is only inspected when one has been set.
    if (param.bias_type.valid()) {
        const bool bias_ok =
                (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
                 param.bias_type.enumv() == DTypeEnum::QuantizedS32) ||
                (param.bias_type.enumv() == param.dst_type.enumv());
        available = available && bias_ok;
    }

    // Channel-count heuristic: separate thresholds for the 2x2 filter and for
    // the 3x3 / 5x5 / 7x7 filters.
    const bool small_2x2 =
            filter_size == 2 && (out_channels <= 10 || in_channels <= 8);
    const bool small_odd =
            (filter_size == 3 || filter_size == 5 || filter_size == 7) &&
            (out_channels <= 16 || (in_channels <= 4 && out_channels <= 32));
    const bool preferred =
            (small_2x2 || small_odd) && param.bias_mode != BiasMode::BIAS;

    return available && preferred;
}
//! Describes the workspace layout needed for \p param.
//!
//! Slots of the returned bundle:
//!  - 0: padded source copy, 0 bytes when need_src_copy() is false. In large
//!       group mode one IC * IH2 * IW2 buffer per thread, otherwise one per
//!       (group, batch) pair.
//!  - 1: per-thread OH2 * OW2 destination buffer in dst_type elements,
//!       0 bytes when need_dst_copy() is false.
//!  - 2: per-thread OH2 * OW2 int32 buffer; only present when IC > 1.
//!
//! \param param          problem description
//! \param m_large_group  selects the per-thread source layout
WorkspaceBundle direct_int8_stride1::get_bundle(
        const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) {
    const auto& meta = param.filter_meta;
    const size_t thread_count = param.nr_threads;
    const size_t group_count = meta.group;
    const size_t batch_count = param.n;
    const size_t in_channels = meta.icpg;

    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(param, IH2, IW2, OH2, OW2);

    size_t src_size = 0;
    if (need_src_copy(param)) {
        const size_t one_copy = in_channels * IH2 * IW2 * sizeof(int8_t);
        if (m_large_group) {
            src_size = one_copy * thread_count;
        } else {
            src_size = one_copy * group_count * batch_count;
        }
    }

    size_t dst_size = 0;
    if (need_dst_copy(param)) {
        dst_size = OH2 * OW2 * param.dst_type.size() * thread_count;
    }

    if (in_channels > 1) {
        const size_t temp_size = OH2 * OW2 * sizeof(int32_t) * thread_count;
        return {nullptr, {src_size, dst_size, temp_size}};
    }
    return {nullptr, {src_size, dst_size}};
}
//! Copies one input channel into workspace slot 0, surrounded by zeros.
//!
//! When need_src_copy() is false nothing is written. Otherwise the IH2 x IW2
//! plane of the selected channel is zero-filled and the IH source rows are
//! copied into it at row offset PH and column offset PW.
//!
//! \param bundle         workspace bundle; slot 0 receives the copy
//! \param kern_param     kernel parameters (source pointer, sizes, padding)
//! \param ncb_index      ndrange_id[0] is the group, ndrange_id[1] the batch
//! \param workspace_ids  [0] workspace group index, [1] workspace batch index,
//!                       [2] input channel index
void direct_int8_stride1::copy_padding_kern(
        const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param,
        const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
    const auto& meta = kern_param.filter_meta;
    const size_t in_h = kern_param.isz[0];
    const size_t in_w = kern_param.isz[1];
    const size_t in_channels = meta.icpg;
    const size_t pad_h = meta.padding[0];
    const size_t pad_w = meta.padding[1];
    const size_t group_count = meta.group;

    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
    const bool copy_needed = need_src_copy(kern_param);
    const size_t group_stride = IH2 * IW2 * in_channels;

    // Indices locating the destination plane inside the workspace ...
    const size_t ws_group = workspace_ids[0];
    const size_t ws_batch = workspace_ids[1];
    const size_t channel_id = workspace_ids[2];
    // ... and the indices locating the source plane.
    const size_t group_id = ncb_index.ndrange_id[0];
    const size_t batch_id = ncb_index.ndrange_id[1];

    const int8_t* src_plane = static_cast<const int8_t*>(
            kern_param.src<int8_t>(batch_id, group_id, channel_id));
    if (!copy_needed) {
        return;
    }

    int8_t* padded_plane = static_cast<int8_t*>(bundle.get(0)) +
                           ws_group * group_stride +
                           ws_batch * group_count * group_stride +
                           channel_id * IH2 * IW2;
    // Zero the whole plane first so everything outside the copied rows is 0.
    std::memset(padded_plane, 0, sizeof(int8_t) * IH2 * IW2);
    for (size_t row = 0; row < in_h; ++row) {
        std::memcpy(
                padded_plane + (row + pad_h) * IW2 + pad_w, src_plane + row * in_w,
                sizeof(int8_t) * in_w);
    }
}
//! Computes one output channel of one (group, batch) pair.
//!
//! \tparam filter     filter size handed to DISPATCH_FILTER to pick the kernel
//! \tparam bias_mode  bias mode forwarded to the kernel
//! \tparam Op         post-processing functor, constructed from two floats
//!
//! \param bundle         workspace bundle: slot 0 holds the padded source,
//!                       slot 1 a per-thread padded destination, slot 2 a
//!                       per-thread int32 temporary
//! \param kern_param     kernel parameters (pointers, sizes, dtypes)
//! \param ncb_index      ndrange_id[0] is the group, ndrange_id[1] the batch;
//!                       thread_id selects the per-thread workspace slices
//! \param workspace_ids  [0] workspace group index, [1] workspace batch index,
//!                       [2] output channel index
template <size_t filter, BiasMode bias_mode, typename Op>
void direct_int8_stride1::do_conv_kern(
        const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
        const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t FH = kern_param.filter_meta.spatial[0];
    size_t FW = kern_param.filter_meta.spatial[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t GROUP = kern_param.filter_meta.group;
    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
    bool need_src_copy_var = need_src_copy(kern_param);
    bool need_dst_copy_var = need_dst_copy(kern_param);
    // Post-processing through Op only happens for a QuantizedS8 destination.
    bool need_post_process = kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
    // Placeholder arguments; replaced by the real scales below when
    // post-processing is needed.
    Op op = Op(1.0f, 4.0f);
    if (need_post_process) {
        float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
        float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
        op = Op(scale_bias, scale_dst);
    }
    size_t padding_group_size = IH2 * IW2 * IC;
    size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1],
           oc = workspace_ids[2];
    size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1];
    const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id);
    const int8_t* fptr = kern_param.filter<dt_int8>(group_id) + oc * FH * FW * IC;
    void* dst = reinterpret_cast<void*>(
            reinterpret_cast<ptrdiff_t>(kern_param.dst<void>(batch_id, group_id, oc)));
    const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id, oc);
    if (need_src_copy_var) {
        // Read from the padded copy in workspace slot 0 instead of the
        // original source.
        sptr = static_cast<int8_t*>(bundle.get(0)) +
               workspace_group_id * padding_group_size +
               workspace_batch_id * GROUP * padding_group_size;
    }
    void* dptr = nullptr;
    int32_t* tptr = nullptr;
    if (need_dst_copy_var) {
        // Write into this thread's OH2 x OW2 slice of workspace slot 1; the
        // valid part is copied back to dst at the end.
        dptr = reinterpret_cast<void*>(
                reinterpret_cast<ptrdiff_t>(bundle.get(1)) +
                ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size());
    } else {
        dptr = dst;
    }

// One kernel call for input channel `ic`, passing tptr as the int32 buffer and
// dptr as the int8 buffer.
#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic)                \
    conv_bias::conv_direct_stride1_##filter##x##filter##_int8_nchw<       \
            first_ic, last_ic, bias_mode, Op>(                            \
            sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr,       \
            static_cast<int8_t*>(dptr), IH2, IW2, OH2, OW2, op)
// One kernel call for input channel `ic`, passing dptr as the int32 buffer and
// no int8 buffer.
#define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic)                  \
    conv_bias::conv_direct_stride1_##filter##x##filter##_int8_nchw<       \
            first_ic, last_ic, bias_mode, Op>(                            \
            sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr,             \
            static_cast<int32_t*>(dptr), nullptr, IH2, IW2, OH2, OW2, op)
// Walks all input channels (IC >= 2): first with first_ic = true, the middle
// ones with both flags false, and the final one with last_ic = true.
#define KERN1_NEED_POST_PROCESS(filter)                \
    KERN0_NEED_POST_PROCESS(filter, true, false);      \
    for (ic = 1; ic < IC - 1; ++ic) {                  \
        KERN0_NEED_POST_PROCESS(filter, false, false); \
    }                                                  \
    KERN0_NEED_POST_PROCESS(filter, false, true);
// Walks all input channels; last_ic is never set on this path.
#define KERN1_NO_POST_PROCESS(filter)                \
    KERN0_NO_POST_PROCESS(filter, true, false);      \
    for (ic = 1; ic < IC; ++ic) {                    \
        KERN0_NO_POST_PROCESS(filter, false, false); \
    }

    if (need_post_process) {
        size_t ic = 0;
        if (IC == 1) {
            DISPATCH_FILTER(filter, KERN0_NEED_POST_PROCESS, true, true)
        } else {
            // NOTE(review): tptr is an int32_t*, so this offset is counted in
            // int32 elements yet is also scaled by dst_type.size(); presumably
            // size() is 1 for the QuantizedS8 dst that reaches this branch --
            // confirm.
            tptr = static_cast<int32_t*>(bundle.get(2)) +
                   ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size();
            DISPATCH_FILTER(filter, KERN1_NEED_POST_PROCESS)
        }
    } else {
        size_t ic = 0;
        if (IC == 1) {
            DISPATCH_FILTER(filter, KERN0_NO_POST_PROCESS, true, false)
        } else {
            DISPATCH_FILTER(filter, KERN1_NO_POST_PROCESS)
        }
    }
// Undefine every helper macro defined above so none of them leaks past this
// function.
#undef KERN0_NEED_POST_PROCESS
#undef KERN0_NO_POST_PROCESS
#undef KERN1_NEED_POST_PROCESS
#undef KERN1_NO_POST_PROCESS

    if (need_dst_copy_var) {
        // Copy the OW valid elements of each of the OH rows from the OW2-wide
        // workspace buffer back to the real destination.
        rep(oh, OH) {
            std::memcpy(
                    reinterpret_cast<void*>(
                            reinterpret_cast<ptrdiff_t>(dst) +
                            oh * OW * kern_param.dst_type.size()),
                    reinterpret_cast<void*>(
                            reinterpret_cast<ptrdiff_t>(dptr) +
                            oh * OW2 * kern_param.dst_type.size()),
                    kern_param.dst_type.size() * OW);
        }
    }
}
//! Builds the kernels that carry out the convolution described by \p param.
//!
//! A do_conv_kern instantiation is selected from the filter size
//! (2, 3, 5 or 7), the bias mode (NO_BIAS or BROADCAST_CHANNEL_BIAS) and the
//! non-linearity (IDENTITY, RELU or H_SWISH); any other value hits
//! megdnn_assert(0).
//!
//! \param param          problem description
//! \param m_large_group  when true, a single kernel over {group, N, 1} copies
//!                       all IC input channels and then computes all OC output
//!                       channels using per-thread workspace; otherwise two
//!                       kernels are returned, a padding copy over
//!                       {group, N, IC} followed by the convolution over
//!                       {group, N, OC}
//! \return the kernels in the order they were pushed
SmallVector<ConvBiasImpl::NCBKern> direct_int8_stride1::get_kimpls(
        const NCBKernSizeParam& param, bool m_large_group) {
    // Bind by reference: only a few fields are read, no copy is needed.
    const auto& fm = param.filter_meta;
    size_t N = param.n;
    size_t IC = fm.icpg;
    size_t OC = fm.ocpg;
    size_t group = fm.group;
    WorkspaceBundle bundle = get_bundle(param, m_large_group);
    conv_fun do_conv_fun = nullptr;

#define DO_CONV_KERN_FUN(filter, bias_mode, op) \
    do_conv_fun = do_conv_kern<filter, bias_mode, op>;

#define GET_OP_PARAM(i, bias_mode)                                                     \
    switch (param.nonlineMode) {                                                       \
        case param::ConvBias::NonlineMode::IDENTITY:                                   \
            DO_CONV_KERN_FUN(i, bias_mode, TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
            break;                                                                     \
        case param::ConvBias::NonlineMode::RELU:                                       \
            DO_CONV_KERN_FUN(i, bias_mode, ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
            break;                                                                     \
        case param::ConvBias::NonlineMode::H_SWISH:                                    \
            DO_CONV_KERN_FUN(i, bias_mode, HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)  \
            break;                                                                     \
        default:                                                                       \
            megdnn_assert(0);                                                          \
            break;                                                                     \
    }

#define GET_BIAS_MODE_PARAM(i)                                \
    switch (param.bias_mode) {                                \
        case BiasMode::NO_BIAS:                               \
            GET_OP_PARAM(i, BiasMode::NO_BIAS)                \
            break;                                            \
        case BiasMode::BROADCAST_CHANNEL_BIAS:                \
            GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS) \
            break;                                            \
        default:                                              \
            megdnn_assert(0);                                 \
            break;                                            \
    }

#define DISPATCH_CONV_KERN()                \
    switch (param.filter_meta.spatial[0]) { \
        case 2:                             \
            GET_BIAS_MODE_PARAM(2)          \
            break;                          \
        case 3:                             \
            GET_BIAS_MODE_PARAM(3)          \
            break;                          \
        case 5:                             \
            GET_BIAS_MODE_PARAM(5)          \
            break;                          \
        case 7:                             \
            GET_BIAS_MODE_PARAM(7)          \
            break;                          \
        default:                            \
            megdnn_assert(0);               \
            break;                          \
    }

    DISPATCH_CONV_KERN();
    megdnn_assert(do_conv_fun);

// The dispatch macros are only needed for the expansion above; undefine them
// so they do not leak past this function.
#undef DISPATCH_CONV_KERN
#undef GET_BIAS_MODE_PARAM
#undef GET_OP_PARAM
#undef DO_CONV_KERN_FUN

    SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
    if (m_large_group) {
        // One task per (group, batch): pad every input channel, then compute
        // every output channel, addressing the workspace by thread id.
        auto exec_one_group = [bundle, do_conv_fun](
                                      const NCBKernParam& kern_param,
                                      const NCBKernIndex& ncb_index) mutable {
            const auto& fm = kern_param.filter_meta;
            size_t IC = fm.icpg;
            size_t OC = fm.ocpg;
            bundle.set(kern_param.workspace_ptr);
            for (size_t ic = 0; ic < IC; ic++) {
                copy_padding_kern(
                        bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic});
            }
            for (size_t oc = 0; oc < OC; oc++) {
                do_conv_fun(
                        bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc});
            }
        };
        ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
    } else {
        // Two kernels: a padding copy with one task per input channel, then
        // the convolution with one task per output channel. Both address the
        // workspace by ndrange id.
        auto copy_padding = [bundle](
                                    const NCBKernParam& kern_param,
                                    const NCBKernIndex& ncb_index) mutable {
            bundle.set(kern_param.workspace_ptr);
            copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
        };
        ret_kerns.push_back({copy_padding, {group, N, IC}});
        auto do_conv = [bundle, do_conv_fun](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) mutable {
            bundle.set(kern_param.workspace_ptr);
            do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
        };
        ret_kerns.push_back({do_conv, {group, N, OC}});
    }
    return ret_kerns;
}