#include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h"
#include "src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h"
#include "src/arm_common/conv_bias/int8x8x16/conv_direct.h"
#include "src/arm_common/conv_bias/int8x8x16/conv_stride2.h"
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8816_kimpl)
using namespace megdnn;
using namespace arm_common;
namespace {
// Tells whether the stride-1 path has to compute into a scratch output buffer.
// Only the output width (osz[1]) is examined: a copy is required when it is
// not a multiple of 8. The output height (osz[0]) imposes no constraint.
bool need_dst_copy_str1(
        const megdnn::fallback::ConvolutionImpl::NCBKernSizeParam& param) {
    const bool width_is_aligned = (param.osz[1] % 8 == 0);
    return !width_is_aligned;
}
// Tells whether the stride-1 path has to stage the input in the workspace.
// That is the case when any padding is requested, or whenever the output
// itself needs a scratch buffer (see need_dst_copy_str1).
bool need_src_copy_str1(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) {
    const auto& meta = param.filter_meta;
    const bool has_padding = (meta.padding[0] != 0) || (meta.padding[1] != 0);
    return has_padding || need_dst_copy_str1(param);
}
// Computes the working ("rectified") sizes used by the stride-1 path.
//   OH2 : output height, unchanged.
//   OW2 : output width rounded up to the next multiple of 8.
//   IH2 : OH2 + (IH - OH) + 2 * PH.
//   IW2 : OW2 + (IW - OW) + 2 * PW.
// All arithmetic is unsigned (size_t); IH - OH may wrap when IH < OH, and the
// following additions bring the sum back into range.
void get_rectified_size_str1(
        size_t IH, size_t IW, size_t OH, size_t OW, size_t PH, size_t PW, size_t& IH2,
        size_t& IW2, size_t& OH2, size_t& OW2) {
    // Difference between input and output extent plus the padding on both sides.
    const size_t extra_rows = IH - OH + 2 * PH;
    const size_t extra_cols = IW - OW + 2 * PW;

    OH2 = OH;
    OW2 = (OW + 7) / 8 * 8;
    IH2 = OH2 + extra_rows;
    IW2 = OW2 + extra_cols;
}
// Tells whether the stride-2 path has to compute into a scratch output buffer:
// true unless both output extents (osz[0] and osz[1]) are multiples of 8.
bool need_dst_copy_str2(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) {
    const bool height_is_aligned = (param.osz[0] % 8 == 0);
    const bool width_is_aligned = (param.osz[1] % 8 == 0);
    return !(height_is_aligned && width_is_aligned);
}
// Tells whether the stride-2 path has to stage the input in the workspace.
// That is the case when any padding is requested, or whenever the output
// itself needs a scratch buffer (see need_dst_copy_str2).
bool need_src_copy_str2(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) {
    const auto& meta = param.filter_meta;
    const bool has_padding = (meta.padding[0] != 0) || (meta.padding[1] != 0);
    return has_padding || need_dst_copy_str2(param);
}
// Computes the working ("rectified") sizes used by the stride-2 path.
//   OH2 / OW2 : output extents rounded up to the next multiple of 8.
//   IH2 / IW2 : max(2 * O?2 + F? - 2, I?), i.e. never smaller than the real
//               input extent.
// PH and PW are accepted for signature symmetry but do not affect the result.
void get_rectified_size_str2(
        size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH,
        size_t PW, size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) {
    MEGDNN_MARK_USED_VAR(PH);
    MEGDNN_MARK_USED_VAR(PW);

    OH2 = (OH + 7) / 8 * 8;
    OW2 = (OW + 7) / 8 * 8;
    IH2 = std::max(2 * OH2 + FH - 2, IH);
    IW2 = std::max(2 * OW2 + FW - 2, IW);
}
}
// Reports whether the direct stride-1 int8 x int8 -> int16 algorithm can serve
// `param`. Every condition below has to hold; the selection strategy argument
// is ignored.
bool ConvBiasImpl::AlgoI8x8x16Direct::usable(
        const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8816_kimpl,
            midout_iv("AlgoI8x8x16Direct::usable"_hash)) {
        const auto& meta = param.filter_meta;
        const auto filter_size = meta.spatial[0];

        // No bias, identity nonlinearity, NCHW layout, filter not flipped.
        const bool plain_conv = param.bias_mode == BiasMode::NO_BIAS &&
                                param.nonlineMode == NonlineMode::IDENTITY &&
                                meta.format == param::ConvBias::Format::NCHW &&
                                !meta.should_flip;
        // int8 source and filter, int16 destination.
        const bool dtypes_match = param.src_type.enumv() == DTypeEnum::Int8 &&
                                  param.filter_type.enumv() == DTypeEnum::Int8 &&
                                  param.dst_type.enumv() == DTypeEnum::Int16;
        // 2D, undilated, stride 1 in both directions.
        const bool geometry_ok = meta.spatial_ndim == 2 && meta.dilation[0] == 1 &&
                                 meta.dilation[1] == 1 && meta.stride[0] == 1 &&
                                 meta.stride[1] == 1;
        // Square filter of size 2, 3 or 5.
        const bool filter_ok =
                filter_size == meta.spatial[1] &&
                (filter_size == 2 || filter_size == 3 || filter_size == 5);

        return plain_conv && dtypes_match && geometry_ok && filter_ok;
    }
    MIDOUT_END();
    return false;
}
// Describes the workspace layout of the direct stride-1 algorithm as a
// two-part bundle (base pointer left null):
//   part 0 : staged (padded) int8 input; one group-sized slab per thread when
//            group >= nr_threads, otherwise one per (group, batch) pair.
//            Zero when no source copy is needed.
//   part 1 : one int16 OH2 x OW2 scratch output per thread plus 16 extra
//            bytes. Zero when no destination copy is needed.
WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Direct::get_bundle(
        const NCBKernSizeParam& param) const {
    const auto& meta = param.filter_meta;
    const size_t thread_count = param.nr_threads;
    const size_t group = meta.group;
    const size_t batch = param.n;

    size_t OH2, OW2, IH2, IW2;
    get_rectified_size_str1(
            param.isz[0], param.isz[1], param.osz[0], param.osz[1], meta.padding[0],
            meta.padding[1], IH2, IW2, OH2, OW2);

    size_t src_bytes = 0u;
    if (need_src_copy_str1(param)) {
        const size_t slab_bytes = meta.icpg * IH2 * IW2 * sizeof(int8_t);
        const bool large_group = group >= thread_count;
        src_bytes = large_group ? slab_bytes * thread_count
                                : slab_bytes * group * batch;
    }

    size_t dst_bytes = 0u;
    if (need_dst_copy_str1(param)) {
        dst_bytes = OH2 * OW2 * sizeof(int16_t) * thread_count + 16;
    }

    return {nullptr, {src_bytes, dst_bytes}};
}
// Returns the total number of workspace bytes described by get_bundle(param).
size_t ConvBiasImpl::AlgoI8x8x16Direct::get_workspace(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8816_kimpl,
            midout_iv("AlgoI8x8x16Direct::get_workspace"_hash)) {
        return get_bundle(param).total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}
// Stages one input channel into workspace part 0 for the stride-1 path.
// The IH2 x IW2 destination plane is zero-filled, then the IH x IW source
// image is copied row by row at offset (PH, PW). Does nothing to the workspace
// when no source copy is required.
//   workspace_ids : {workspace group slot, workspace batch slot, channel}.
//   ncb_index     : supplies the (group, batch) used to locate the source.
void ConvBiasImpl::AlgoI8x8x16Direct::copy_padding_kern(
        const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param,
        const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
    const auto& meta = kern_param.filter_meta;
    const size_t IH = kern_param.isz[0], IW = kern_param.isz[1];
    const size_t pad_h = meta.padding[0], pad_w = meta.padding[1];
    const size_t group_count = meta.group;

    size_t OH2, OW2, IH2, IW2;
    get_rectified_size_str1(
            IH, IW, kern_param.osz[0], kern_param.osz[1], pad_h, pad_w, IH2, IW2, OH2,
            OW2);
    const bool copy_needed = need_src_copy_str1(kern_param);

    const size_t plane_size = IH2 * IW2;
    const size_t slab_size = plane_size * meta.icpg;

    const size_t ws_group = workspace_ids[0];
    const size_t ws_batch = workspace_ids[1];
    const size_t channel = workspace_ids[2];
    const size_t group_id = ncb_index.ndrange_id[0];
    const size_t batch_id = ncb_index.ndrange_id[1];

    const int8_t* src_plane = kern_param.src<int8_t>(batch_id, group_id, channel);
    if (!copy_needed) {
        return;
    }

    int8_t* dst_plane = static_cast<int8_t*>(bundle.get(0)) +
                        (ws_group * slab_size + ws_batch * group_count * slab_size +
                         channel * plane_size);
    std::memset(dst_plane, 0, sizeof(int8_t) * plane_size);
    for (size_t row = 0; row < IH; ++row) {
        std::memcpy(
                dst_plane + (row + pad_h) * IW2 + pad_w, src_plane + row * IW,
                sizeof(int8_t) * IW);
    }
}
void ConvBiasImpl::AlgoI8x8x16Direct::do_conv_kern(
const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t IC = kern_param.filter_meta.icpg;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t GROUP = kern_param.filter_meta.group;
size_t OH2, OW2, IH2, IW2;
get_rectified_size_str1(IH, IW, OH, OW, PH, PW, IH2, IW2, OH2, OW2);
bool need_src_copy_var = need_src_copy_str1(kern_param);
bool need_dst_copy_var = need_dst_copy_str1(kern_param);
size_t padding_group_size = IH2 * IW2 * IC;
using Func = std::function<void(
const int8_t*, const int8_t*, int16_t*, size_t, size_t, size_t, size_t,
size_t, size_t)>;
Func fun_not_add_to_dst = nullptr, fun_add_to_dst = nullptr;
if (FH == 2) {
fun_not_add_to_dst = conv_bias::conv_direct_2x2_sc_int8_int8_int16<false>;
fun_add_to_dst = conv_bias::conv_direct_2x2_sc_int8_int8_int16<true>;
} else if (FH == 3) {
fun_not_add_to_dst = conv_bias::conv_direct_3x3_sc_int8_int8_int16<false>;
fun_add_to_dst = conv_bias::conv_direct_3x3_sc_int8_int8_int16<true>;
} else if (FH == 5) {
fun_not_add_to_dst = conv_bias::conv_direct_5x5_sc_int8_int8_int16<false>;
fun_add_to_dst = conv_bias::conv_direct_5x5_sc_int8_int8_int16<true>;
}
size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1],
oc = workspace_ids[2];
size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1];
const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id);
const int8_t* filter = kern_param.filter<dt_int8>(group_id) + oc * FH * FW * IC;
int16_t* dst = kern_param.dst<dt_int16>(batch_id, group_id, oc);
if (need_src_copy_var) {
sptr = static_cast<int8_t*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
workspace_batch_id * GROUP * padding_group_size;
}
int16_t* dptr = nullptr;
if (need_dst_copy_var) {
dptr = static_cast<int16_t*>(bundle.get(1)) + ncb_index.thread_id * OH2 * OW2;
} else {
dptr = dst;
}
fun_not_add_to_dst(sptr, filter, dptr, IH2, IW2, OH2, OW2, 0, 0);
for (size_t ic = 1; ic < IC; ++ic) {
fun_add_to_dst(
sptr + ic * IH2 * IW2, filter + ic * FH * FW, dptr, IH2, IW2, OH2, OW2,
0, 0);
}
if (need_dst_copy_var) {
rep(oh, OH) {
std::memcpy(dst + oh * OW, dptr + oh * OW2, sizeof(int16_t) * OW);
}
}
}
// Builds the kernel list for the direct stride-1 algorithm.
// When group >= nr_threads a single kernel over {group, N, 1} stages all input
// channels and then computes all output channels of its group, using the
// workspace slot of the executing thread. Otherwise two kernels are emitted:
// input staging over {group, N, IC} followed by convolution over
// {group, N, OC}, both indexed by their ndrange id.
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Direct::get_kimpls(
        const NCBKernSizeParam& param) const {
    const auto& meta = param.filter_meta;
    const size_t batch = param.n;
    const size_t in_channels = meta.icpg;
    const size_t out_channels = meta.ocpg;
    const size_t group = meta.group;
    WorkspaceBundle bundle = get_bundle(param);
    SmallVector<NCBKern> kerns;

    if (group >= param.nr_threads) {
        auto run_whole_group = [bundle](
                                       const NCBKernParam& kern_param,
                                       const NCBKernIndex& ncb_index) mutable {
            const size_t ic_count = kern_param.filter_meta.icpg;
            const size_t oc_count = kern_param.filter_meta.ocpg;
            bundle.set(kern_param.workspace_ptr);
            for (size_t ic = 0; ic < ic_count; ++ic) {
                copy_padding_kern(
                        bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic});
            }
            for (size_t oc = 0; oc < oc_count; ++oc) {
                do_conv_kern(
                        bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc});
            }
        };
        kerns.push_back({run_whole_group, {group, batch, 1_z}});
        return kerns;
    }

    auto stage_input = [bundle](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) mutable {
        bundle.set(kern_param.workspace_ptr);
        copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
    };
    kerns.push_back({stage_input, {group, batch, in_channels}});

    auto convolve = [bundle](
                            const NCBKernParam& kern_param,
                            const NCBKernIndex& ncb_index) mutable {
        bundle.set(kern_param.workspace_ptr);
        do_conv_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
    };
    kerns.push_back({convolve, {group, batch, out_channels}});
    return kerns;
}
// Returns the kernels produced by get_kimpls(param), wrapped in a midout
// region.
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8816_kimpl,
            midout_iv("AlgoI8x8x16Direct::dispatch_kerns"_hash)) {
        SmallVector<NCBKern> kerns = get_kimpls(param);
        return kerns;
    }
    MIDOUT_END();
    return {};
}
// Reports whether the stride-2 int8 x int8 -> int16 algorithm can serve
// `param`. Every condition below has to hold; the selection strategy argument
// is ignored.
//
// The spatial_ndim == 2 requirement mirrors AlgoI8x8x16Direct::usable: the
// kernels of this algorithm only ever read two spatial extents (isz[0..1],
// osz[0..1], spatial[0..1]), so any other dimensionality must be rejected.
bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(
        const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8816_kimpl,
            midout_iv("AlgoI8x8x16Stride2::usable"_hash)) {
        auto&& fm = param.filter_meta;
        auto FH = fm.spatial[0];
        return param.bias_mode == BiasMode::NO_BIAS &&
               param.nonlineMode == NonlineMode::IDENTITY &&
               fm.format == param::ConvBias::Format::NCHW && !fm.should_flip &&
               // int8 source and filter, int16 destination.
               param.src_type.enumv() == DTypeEnum::Int8 &&
               param.filter_type.enumv() == DTypeEnum::Int8 &&
               param.dst_type.enumv() == DTypeEnum::Int16 &&
               // 2D, undilated, stride 2 in both directions.
               fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
               fm.stride[0] == 2 && fm.stride[1] == 2 &&
               // Square filter of size 2, 3 or 5.
               FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5);
    }
    MIDOUT_END();
    return false;
}
// Describes the workspace layout of the stride-2 algorithm as a two-part
// bundle (base pointer left null):
//   part 0 : staged (padded) int8 input; one group-sized slab per thread when
//            group >= nr_threads, otherwise one per (group, batch) pair.
//            Zero when no source copy is needed.
//   part 1 : one int16 OH2 x OW2 scratch output per thread plus 16 extra
//            bytes. Zero when no destination copy is needed.
WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Stride2::get_bundle(
        const NCBKernSizeParam& param) const {
    const auto& meta = param.filter_meta;
    const size_t thread_count = param.nr_threads;
    const size_t group = meta.group;
    const size_t batch = param.n;

    size_t OH2, OW2, IH2, IW2;
    get_rectified_size_str2(
            param.isz[0], param.isz[1], param.osz[0], param.osz[1], meta.spatial[0],
            meta.spatial[1], meta.padding[0], meta.padding[1], IH2, IW2, OH2, OW2);

    size_t src_bytes = 0u;
    if (need_src_copy_str2(param)) {
        const size_t slab_bytes = meta.icpg * IH2 * IW2 * sizeof(int8_t);
        const bool large_group = group >= thread_count;
        src_bytes = large_group ? slab_bytes * thread_count
                                : slab_bytes * group * batch;
    }

    size_t dst_bytes = 0u;
    if (need_dst_copy_str2(param)) {
        dst_bytes = OH2 * OW2 * sizeof(int16_t) * thread_count + 16;
    }

    return {nullptr, {src_bytes, dst_bytes}};
}
// Returns the total number of workspace bytes described by get_bundle(param).
size_t ConvBiasImpl::AlgoI8x8x16Stride2::get_workspace(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8816_kimpl,
            midout_iv("AlgoI8x8x16Stride2::get_workspace"_hash)) {
        return get_bundle(param).total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}
// Stages one input channel into workspace part 0 for the stride-2 path.
// The IH2 x IW2 destination plane is zero-filled, then the IH x IW source
// image is copied row by row at offset (PH, PW). Does nothing to the workspace
// when no source copy is required.
//   workspace_ids : {workspace group slot, workspace batch slot, channel}.
//   ncb_index     : supplies the (group, batch) used to locate the source.
void ConvBiasImpl::AlgoI8x8x16Stride2::copy_padding_kern(
        const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param,
        const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
    const auto& meta = kern_param.filter_meta;
    const size_t IH = kern_param.isz[0], IW = kern_param.isz[1];
    const size_t pad_h = meta.padding[0], pad_w = meta.padding[1];
    const size_t group_count = meta.group;

    size_t IH2, IW2, OH2, OW2;
    get_rectified_size_str2(
            IH, IW, kern_param.osz[0], kern_param.osz[1], meta.spatial[0],
            meta.spatial[1], pad_h, pad_w, IH2, IW2, OH2, OW2);
    const bool copy_needed = need_src_copy_str2(kern_param);

    const size_t plane_size = IH2 * IW2;
    const size_t slab_size = plane_size * meta.icpg;

    const size_t ws_group = workspace_ids[0];
    const size_t ws_batch = workspace_ids[1];
    const size_t channel = workspace_ids[2];
    const size_t group_id = ncb_index.ndrange_id[0];
    const size_t batch_id = ncb_index.ndrange_id[1];

    const int8_t* src_plane = kern_param.src<int8_t>(batch_id, group_id, channel);
    if (!copy_needed) {
        return;
    }

    int8_t* dst_plane = static_cast<int8_t*>(bundle.get(0)) +
                        (ws_group * slab_size + ws_batch * group_count * slab_size +
                         channel * plane_size);
    std::memset(dst_plane, 0, sizeof(int8_t) * plane_size);
    for (size_t row = 0; row < IH; ++row) {
        std::memcpy(
                dst_plane + (row + pad_h) * IW2 + pad_w, src_plane + row * IW,
                sizeof(int8_t) * IW);
    }
}
void ConvBiasImpl::AlgoI8x8x16Stride2::do_conv_kern(
const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t IC = kern_param.filter_meta.icpg;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t GROUP = kern_param.filter_meta.group;
size_t IH2, IW2, OH2, OW2;
get_rectified_size_str2(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2);
bool need_src_copy_var = need_src_copy_str2(kern_param);
bool need_dst_copy_var = need_dst_copy_str2(kern_param);
size_t padding_group_size = IH2 * IW2 * IC;
using Func = std::function<void(
const int8_t*, const int8_t*, int16_t*, size_t, size_t, size_t, size_t,
size_t, size_t)>;
Func fun_not_add_to_dst = nullptr, fun_add_to_dst = nullptr;
if (FH == 2) {
fun_not_add_to_dst = conv_bias::conv_stride2_2x2_sc_int8_int8_int16<false>;
fun_add_to_dst = conv_bias::conv_stride2_2x2_sc_int8_int8_int16<true>;
} else if (FH == 3) {
fun_not_add_to_dst = conv_bias::conv_stride2_3x3_sc_int8_int8_int16<false>;
fun_add_to_dst = conv_bias::conv_stride2_3x3_sc_int8_int8_int16<true>;
} else if (FH == 5) {
fun_not_add_to_dst = conv_bias::conv_stride2_5x5_sc_int8_int8_int16<false>;
fun_add_to_dst = conv_bias::conv_stride2_5x5_sc_int8_int8_int16<true>;
}
size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1],
oc = workspace_ids[2];
size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1];
const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id);
const int8_t* filter = kern_param.filter<dt_int8>(group_id) + oc * FH * FW * IC;
int16_t* dst = kern_param.dst<dt_int16>(batch_id, group_id, oc);
if (need_src_copy_var) {
sptr = static_cast<int8_t*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
workspace_batch_id * GROUP * padding_group_size;
}
int16_t* dptr = nullptr;
if (need_dst_copy_var) {
dptr = static_cast<int16_t*>(bundle.get(1)) + ncb_index.thread_id * OH2 * OW2;
} else {
dptr = dst;
}
fun_not_add_to_dst(sptr, filter, dptr, IH2, IW2, OH2, OW2, 0, 0);
for (size_t ic = 1; ic < IC; ++ic) {
fun_add_to_dst(
sptr + ic * IH2 * IW2, filter + ic * FH * FW, dptr, IH2, IW2, OH2, OW2,
0, 0);
}
if (need_dst_copy_var) {
rep(oh, OH) {
std::memcpy(dst + oh * OW, dptr + oh * OW2, sizeof(int16_t) * OW);
}
}
}
// Builds the kernel list for the stride-2 algorithm.
// When group >= nr_threads a single kernel over {group, N, 1} stages all input
// channels and then computes all output channels of its group, using the
// workspace slot of the executing thread. Otherwise two kernels are emitted:
// input staging over {group, N, IC} followed by convolution over
// {group, N, OC}, both indexed by their ndrange id.
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Stride2::get_kimpls(
        const NCBKernSizeParam& param) const {
    const auto& meta = param.filter_meta;
    const size_t batch = param.n;
    const size_t in_channels = meta.icpg;
    const size_t out_channels = meta.ocpg;
    const size_t group = meta.group;
    WorkspaceBundle bundle = get_bundle(param);
    SmallVector<NCBKern> kerns;

    if (group >= param.nr_threads) {
        auto run_whole_group = [bundle](
                                       const NCBKernParam& kern_param,
                                       const NCBKernIndex& ncb_index) mutable {
            const size_t ic_count = kern_param.filter_meta.icpg;
            const size_t oc_count = kern_param.filter_meta.ocpg;
            bundle.set(kern_param.workspace_ptr);
            for (size_t ic = 0; ic < ic_count; ++ic) {
                copy_padding_kern(
                        bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic});
            }
            for (size_t oc = 0; oc < oc_count; ++oc) {
                do_conv_kern(
                        bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc});
            }
        };
        kerns.push_back({run_whole_group, {group, batch, 1_z}});
        return kerns;
    }

    auto stage_input = [bundle](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) mutable {
        bundle.set(kern_param.workspace_ptr);
        copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
    };
    kerns.push_back({stage_input, {group, batch, in_channels}});

    auto convolve = [bundle](
                            const NCBKernParam& kern_param,
                            const NCBKernIndex& ncb_index) mutable {
        bundle.set(kern_param.workspace_ptr);
        do_conv_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
    };
    kerns.push_back({convolve, {group, batch, out_channels}});
    return kerns;
}
// Returns the kernels produced by get_kimpls(param), wrapped in a midout
// region.
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8816_kimpl,
            midout_iv("AlgoI8x8x16Stride2::dispatch_kerns"_hash)) {
        SmallVector<NCBKern> kerns = get_kimpls(param);
        return kerns;
    }
    MIDOUT_END();
    return {};
}
// Reports whether the stride-2 / filter-2 algorithm can serve `param`:
// no bias, identity nonlinearity, exactly one thread, and acceptance by
// conv_bias::can_conv_int8x8x16_stride2_flt2. The selection strategy argument
// is ignored.
bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable(
        const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8816_kimpl,
            midout_iv("AlgoI8x8x16Stride2Filter2::usable"_hash)) {
        if (param.bias_mode != BiasMode::NO_BIAS) {
            return false;
        }
        if (param.nonlineMode != NonlineMode::IDENTITY) {
            return false;
        }
        if (param.nr_threads != 1_z) {
            return false;
        }
        return conv_bias::can_conv_int8x8x16_stride2_flt2(param);
    }
    MIDOUT_END();
    return false;
}
// Returns the workspace size in bytes as reported by
// conv_bias::get_workspace_in_bytes_conv_int8x8x16_stride2_flt2.
size_t ConvBiasImpl::AlgoI8x8x16Stride2Filter2::get_workspace(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8816_kimpl,
            midout_iv("AlgoI8x8x16Stride2Filter2::get_workspace"_hash)) {
        const size_t workspace_bytes =
                conv_bias::get_workspace_in_bytes_conv_int8x8x16_stride2_flt2(param);
        return workspace_bytes;
    }
    MIDOUT_END();
    return 0;
}
// Emits a single kernel over {group, 1, 1}. For each group the kernel takes a
// copy of the kernel parameters, advances its src/dst/filter/bias pointers by
// that group's offsets (batch index 0), and runs
// conv_bias::conv_int8x8x16_stride2_flt2 on the adjusted copy.
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Stride2Filter2::
        dispatch_kerns(const NCBKernSizeParam& param) const {
    const size_t group = param.filter_meta.group;

    auto run_group = [](const NCBKernParam& kern_param,
                        const NCBKernIndex& ncb_index) {
        MIDOUT_BEGIN(
                megdnn_arm_common_conv_bias_int8816_kimpl,
                midout_iv("AlgoI8x8x16Stride2Filter2::dispatch_kerns"_hash)) {
            const size_t group_id = ncb_index.ndrange_id[0];
            auto group_param = kern_param;
            group_param.src_ptr += kern_param.src_offset(0, group_id);
            group_param.dst_ptr += kern_param.dst_offset(0, group_id);
            group_param.filter_ptr += kern_param.filter_offset(group_id);
            group_param.bias_ptr += kern_param.bias_offset(0, group_id);
            conv_bias::conv_int8x8x16_stride2_flt2(group_param);
        }
        MIDOUT_END();
    };

    return {{run_group, {group, 1_z, 1_z}}};
}
// Reports whether the NCHW44 channel-wise int8 x int8 -> int16 algorithm can
// serve `param`. Every condition below has to hold; the selection strategy
// argument is ignored.
bool ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::usable(
        const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
    const auto& meta = param.filter_meta;
    const auto filter_size = meta.spatial[0];

    // int8 source and filter, int16 destination.
    const bool dtypes_match = param.src_type.enumv() == DTypeEnum::Int8 &&
                              param.filter_type.enumv() == DTypeEnum::Int8 &&
                              param.dst_type.enumv() == DTypeEnum::Int16;
    // NCHW44 layout, any bias mode except BIAS, identity nonlinearity,
    // filter not flipped.
    const bool mode_ok = meta.format == param::Convolution::Format::NCHW44 &&
                         param.bias_mode != megdnn::BiasMode::BIAS &&
                         param.nonlineMode == megdnn::NonlineMode::IDENTITY &&
                         !meta.should_flip;
    // 2D, undilated, equal strides of 1 or 2.
    const bool geometry_ok = meta.spatial_ndim == 2 && meta.dilation[0] == 1 &&
                             meta.dilation[1] == 1 &&
                             meta.stride[0] == meta.stride[1] &&
                             (meta.stride[0] == 1 || meta.stride[0] == 2);
    // Square filter of size 2, 3 or 5.
    const bool filter_ok = filter_size == meta.spatial[1] &&
                           (filter_size == 2 || filter_size == 3 || filter_size == 5);
    // One input and one output channel per group, group count divisible by 4.
    const bool channel_wise = meta.icpg == 1 && meta.ocpg == 1 && meta.group % 4 == 0;

    return dtypes_match && mode_ok && geometry_ok && filter_ok && channel_wise;
}
// Returns the workspace size in bytes of the stride-specific channel-wise
// implementation: the stride1 bundle for stride 1, the stride2 bundle for
// stride 2, and 0 for any other stride. Asserts that both strides are equal.
size_t ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8816_kimpl,
            midout_iv("AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace"_hash)) {
        const size_t stride_h = param.filter_meta.stride[0];
        const size_t stride_w = param.filter_meta.stride[1];
        megdnn_assert(stride_h == stride_w);
        switch (stride_h) {
            case 1:
                return channel_wise_nchw44_8x8x16::stride1::get_bundle(param)
                        .total_size_in_bytes();
            case 2:
                return channel_wise_nchw44_8x8x16::stride2::get_bundle(param)
                        .total_size_in_bytes();
            default:
                return 0;
        }
    }
    MIDOUT_END();
    return 0;
}
// Returns the kernels of the stride-specific channel-wise implementation:
// stride1::get_kimpls when both strides are 1, stride2::get_kimpls when both
// are 2, and an empty list for every other stride combination.
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::
        AlgoS8x8x16ChanWiseStride1Stride2NCHW44::dispatch_kerns(
                const NCBKernSizeParam& param) const {
    const size_t stride_h = param.filter_meta.stride[0];
    const size_t stride_w = param.filter_meta.stride[1];

    // Unequal strides are not handled by either implementation.
    if (stride_h != stride_w) {
        return {};
    }

    if (stride_h == 1) {
        MIDOUT_BEGIN(
                megdnn_arm_common_conv_bias_int8816_kimpl,
                midout_iv(
                        "AlgoS8x8x16ChanWiseStride1Stride2NCHW44_dispatch_kerns"_hash)) {
            return channel_wise_nchw44_8x8x16::stride1::get_kimpls(param);
        }
        MIDOUT_END();
        return {};
    }

    if (stride_h == 2) {
        MIDOUT_BEGIN(
                megdnn_arm_common_conv_bias_int8816_kimpl,
                midout_iv("AlgoS8x8x16ChanWiseStride2NCHW44_dispatch_kerns"_hash)) {
            return channel_wise_nchw44_8x8x16::stride2::get_kimpls(param);
        }
        MIDOUT_END();
        return {};
    }

    return {};
}