#include "src/x86/conv_bias/f32/algos.h"
#include <unordered_map>
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/convolution/img2col_helper.h"
#include "src/x86/conv_bias/f32/do_conv_stride2.h"
#include "src/x86/conv_bias/opr_impl.h"
#include "src/x86/conv_bias/postprocess_helper.h"
#include "src/x86/convolution/convolution_direct_special_cases.h"
#include "src/x86/handle.h"
#include "midout.h"
using namespace megdnn;
using namespace x86;
namespace {
bool need_dst_copy(const fallback::ConvBiasImpl::NCBKernSizeParam& param) {
if (param.osz[0] % 8 != 0 || param.osz[1] % 8 != 0) {
return true;
}
return false;
}
bool need_src_copy(const fallback::ConvBiasImpl::NCBKernSizeParam& param) {
if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) {
return true;
}
return need_dst_copy(param);
}
void get_rectified_size(
size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH,
size_t PW, size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) {
MEGDNN_MARK_USED_VAR(PH);
MEGDNN_MARK_USED_VAR(PW);
OH2 = (OH + 7) & ~7;
OW2 = (OW + 7) & ~7;
IH2 = 2 * OH2 + FH - 2;
IW2 = 2 * OW2 + FW - 2;
IH2 = std::max(IH2, IH);
IW2 = std::max(IW2, IW);
}
}
#define GET_KERN \
auto fm = param.filter_meta; \
size_t N = param.n; \
size_t IC = param.filter_meta.icpg; \
size_t OC = param.filter_meta.ocpg; \
size_t group = fm.group; \
bool large_group = group >= param.nr_threads; \
WorkspaceBundle bundle = get_bundle(param); \
SmallVector<NCBKern> ret_kerns; \
if (large_group) { \
auto exec_one_group = [bundle]( \
const NCBKernParam& kern_param, \
const NCBKernIndex& ncb_index) mutable { \
bundle.set(kern_param.workspace_ptr); \
auto fm = kern_param.filter_meta; \
size_t IC = fm.icpg; \
size_t OC = fm.ocpg; \
for (size_t ic = 0; ic < IC; ic++) { \
copy_padding_kern( \
bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); \
} \
for (size_t oc = 0; oc < OC; oc++) { \
do_conv_kern( \
bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc}); \
} \
}; \
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); \
} else { \
auto copy_padding = [bundle]( \
const NCBKernParam& kern_param, \
const NCBKernIndex& ncb_index) mutable { \
bundle.set(kern_param.workspace_ptr); \
copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); \
}; \
ret_kerns.push_back({copy_padding, {group, N, IC}}); \
auto do_conv = [bundle]( \
const NCBKernParam& kern_param, \
const NCBKernIndex& ncb_index) mutable { \
bundle.set(kern_param.workspace_ptr); \
do_conv_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); \
}; \
ret_kerns.push_back({do_conv, {group, N, OC}}); \
} \
return ret_kerns;
bool ConvBiasImpl::AlgoDirect::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
return fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.spatial[0] <= 7 && fm.stride[0] == 1 &&
fm.stride[1] == 1;
}
WorkspaceBundle ConvBiasImpl::AlgoDirect::get_bundle(
const NCBKernSizeParam& param) const {
auto&& fm = param.filter_meta;
size_t nr_threads = param.nr_threads;
size_t group = fm.group, batch = param.n;
auto IC = fm.icpg, IH = param.isz[0], IW = param.isz[1];
auto FH = fm.spatial[0], FW = fm.spatial[1];
auto OH = param.osz[0], OW = param.osz[1];
size_t OH2, OW2, IH2, IW2;
get_rectified_img_size(
IH, IW, FH, FW, OH, OW, fm.padding[0], fm.padding[1], IH2, IW2, OH2, OW2);
size_t part0 = 0u, part1 = 0u;
bool large_group = group >= param.nr_threads;
if (IH != IH2 || IW != IW2) {
part0 = large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads
: IC * IH2 * IW2 * sizeof(float) * group * batch;
}
if (OH != OH2 || OW != OW2) {
part1 = OH2 * OW2 * sizeof(float) * nr_threads;
}
return {nullptr, {part0, part1}};
}
size_t ConvBiasImpl::AlgoDirect::get_workspace(const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}
void ConvBiasImpl::AlgoDirect::copy_padding_kern(
const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = kern_param.filter_meta.icpg;
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t GROUP = kern_param.filter_meta.group;
size_t OH2, OW2, IH2, IW2;
get_rectified_img_size(IH, IW, FH, FW, OH, OW, PH, PW, IH2, IW2, OH2, OW2);
bool rectify_src = (IH != IH2 || IW != IW2);
size_t padding_group_size = IH2 * IW2 * IC;
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
size_t channel_id = workspace_ids[2];
const float* sptr =
static_cast<const float*>(kern_param.src<float>(batch_id, group_id)) +
channel_id * IH * IW;
size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1],
workspace_channel_id = workspace_ids[2];
if (rectify_src) {
float* sptr_base = static_cast<float*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
workspace_batch_id * GROUP * padding_group_size +
workspace_channel_id * IH2 * IW2;
std::memset(sptr_base, 0, sizeof(float) * IH2 * IW2);
rep(ih, IH) {
std::memcpy(
sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW,
sizeof(float) * IW);
}
}
};
#define DISPATCH \
if (is_supported(SIMDType::FMA)) { \
DISPATCH_SIMD(fma) \
} else if (is_supported(SIMDType::AVX)) { \
DISPATCH_SIMD(avx) \
} else if (is_supported(SIMDType::SSE)) { \
DISPATCH_SIMD(sse) \
} else { \
megdnn_throw("no fma/avx/sse detected"); \
}
#define DISPATCH_SIMD(simd) \
if (is_xcorr) { \
DISPATCH_SIMD_MODE(simd, xcorr) \
} else { \
DISPATCH_SIMD_MODE(simd, conv) \
}
#define DISPATCH_SIMD_MODE(simd, mode) \
switch (FH) { \
case 1: \
DISPATCH_SIMD_MODE_FSIZE(simd, mode, 1); \
break; \
case 2: \
DISPATCH_SIMD_MODE_FSIZE(simd, mode, 2); \
break; \
case 3: \
DISPATCH_SIMD_MODE_FSIZE(simd, mode, 3); \
break; \
case 4: \
DISPATCH_SIMD_MODE_FSIZE(simd, mode, 4); \
break; \
case 5: \
DISPATCH_SIMD_MODE_FSIZE(simd, mode, 5); \
break; \
case 6: \
DISPATCH_SIMD_MODE_FSIZE(simd, mode, 6); \
break; \
case 7: \
DISPATCH_SIMD_MODE_FSIZE(simd, mode, 7); \
break; \
default: \
megdnn_throw("unsupported filter size"); \
}
#define DISPATCH_SIMD_MODE_FSIZE(simd, mode, fsize) \
func = detail::convolution_##mode##_fh##fsize##_##simd;
void ConvBiasImpl::AlgoDirect::do_conv_kern(
const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t IC = kern_param.filter_meta.icpg;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
auto is_xcorr = !kern_param.filter_meta.should_flip;
size_t GROUP = kern_param.filter_meta.group;
size_t OH2, OW2, IH2, IW2;
get_rectified_img_size(IH, IW, FH, FW, OH, OW, PH, PW, IH2, IW2, OH2, OW2);
bool rectify_src = (IH != IH2 || IW != IW2);
bool rectify_dst = (OH != OH2 || OW != OW2);
size_t padding_group_size = IH2 * IW2 * IC;
std::function<void(
const float*, const float*, float*, size_t, size_t, size_t, size_t, size_t)>
func = nullptr;
DISPATCH;
size_t bias_offset = 0;
if (kern_param.bias_mode == megdnn::BiasMode::BIAS) {
bias_offset = OH * OW;
} else if (kern_param.bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
bias_offset = 1_z;
}
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1],
oc = workspace_ids[2];
const float* sptr = kern_param.src<float>(batch_id, group_id);
const float* filter = kern_param.filter<float>(group_id) + oc * FH * FW * IC;
const float* bias_ptr =
kern_param.bias<float>(batch_id, group_id) + oc * bias_offset;
float* dst = kern_param.dst<float>(batch_id, group_id) + oc * OH * OW;
if (rectify_src) {
sptr = static_cast<float*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
workspace_batch_id * GROUP * padding_group_size;
}
float* dptr = nullptr;
if (rectify_dst) {
dptr = static_cast<float*>(bundle.get(1)) + ncb_index.thread_id * OH2 * OW2;
} else {
dptr = dst;
}
std::memset(dptr, 0, sizeof(float) * OH2 * OW2);
rep(ic, IC) {
func(sptr + ic * IH2 * IW2, filter + ic * FH * FW, dptr, IH2, IW2, OH2, OW2,
FW);
}
if (rectify_dst) {
rep(oh, OH) { std::memcpy(dst + oh * OW, dptr + oh * OW2, sizeof(float) * OW); }
}
PostProcess<dt_float32>::run(
dst, const_cast<float*>(bias_ptr), dst, kern_param.bias_mode,
kern_param.nonlineMode, kern_param.bias_type, kern_param.dst_type, 1_z, 1_z,
OH, OW);
}
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDirect::get_kimpls(
const NCBKernSizeParam& param) const {
GET_KERN;
}
bool ConvBiasImpl::AlgoDirectStride2::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
return param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
}
WorkspaceBundle ConvBiasImpl::AlgoDirectStride2::get_bundle(
const NCBKernSizeParam& param) const {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(OC);
MEGDNN_MARK_USED_VAR(SH);
MEGDNN_MARK_USED_VAR(SW);
size_t nr_threads = param.nr_threads;
size_t group = param.filter_meta.group;
size_t batch = param.n;
size_t src_size = 0, dst_size = 0;
size_t IH2, IW2, OH2, OW2;
get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2);
bool large_group = group >= param.nr_threads;
if (need_src_copy(param)) {
src_size = large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads
: IC * IH2 * IW2 * sizeof(float) * group * batch;
}
if (need_dst_copy(param)) {
dst_size = OH2 * OW2 * sizeof(float) * nr_threads;
}
return WorkspaceBundle(nullptr, {src_size, dst_size});
}
size_t ConvBiasImpl::AlgoDirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}
void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern(
const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = kern_param.filter_meta.icpg;
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t GROUP = kern_param.filter_meta.group;
size_t OH2, OW2, IH2, IW2;
get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2);
bool rectify_src = need_src_copy(kern_param);
size_t padding_group_size = IH2 * IW2 * IC;
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
size_t channel_id = workspace_ids[2];
const float* sptr =
static_cast<const float*>(kern_param.src<float>(batch_id, group_id)) +
channel_id * IH * IW;
size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1],
workspace_channel_id = workspace_ids[2];
if (rectify_src) {
float* sptr_base = static_cast<float*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
workspace_batch_id * GROUP * padding_group_size +
workspace_channel_id * IH2 * IW2;
std::memset(sptr_base, 0, sizeof(float) * IH2 * IW2);
rep(ih, IH) {
std::memcpy(
sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW,
sizeof(float) * IW);
}
}
};
void ConvBiasImpl::AlgoDirectStride2::do_conv_kern(
const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t IC = kern_param.filter_meta.icpg;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t GROUP = kern_param.filter_meta.group;
size_t OH2, OW2, IH2, IW2;
get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2);
bool rectify_src = need_src_copy(kern_param);
bool rectify_dst = need_dst_copy(kern_param);
size_t padding_group_size = IH2 * IW2 * IC;
using Func = std::function<void(
const float*, const float*, float*, size_t, size_t, size_t, size_t, size_t,
size_t)>;
Func func_no_add_dst = nullptr, func_add_dst = nullptr;
if (FH == 2) {
func_no_add_dst = conv_general_simd::do_conv_2x2_stride2<false>;
func_add_dst = conv_general_simd::do_conv_2x2_stride2<true>;
} else if (FH == 3) {
func_no_add_dst = conv_general_simd::do_conv_3x3_stride2<false>;
func_add_dst = conv_general_simd::do_conv_3x3_stride2<true>;
} else if (FH == 5) {
func_no_add_dst = conv_general_simd::do_conv_5x5_stride2<false>;
func_add_dst = conv_general_simd::do_conv_5x5_stride2<true>;
} else if (FH == 7) {
func_no_add_dst = conv_general_simd::do_conv_7x7_stride2<false>;
func_add_dst = conv_general_simd::do_conv_7x7_stride2<true>;
}
size_t bias_offset = 0;
if (kern_param.bias_mode == megdnn::BiasMode::BIAS) {
bias_offset = OH * OW;
} else if (kern_param.bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
bias_offset = 1_z;
}
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1],
oc = workspace_ids[2];
const float* sptr = kern_param.src<float>(batch_id, group_id);
const float* filter = kern_param.filter<float>(group_id) + oc * FH * FW * IC;
const float* bias_ptr =
kern_param.bias<float>(batch_id, group_id) + oc * bias_offset;
float* dst = kern_param.dst<float>(batch_id, group_id) + oc * OH * OW;
if (rectify_src) {
sptr = static_cast<float*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
workspace_batch_id * GROUP * padding_group_size;
}
float* dptr = nullptr;
if (rectify_dst) {
dptr = static_cast<float*>(bundle.get(1)) + ncb_index.thread_id * OH2 * OW2;
} else {
dptr = dst;
}
func_no_add_dst(sptr, filter, dptr, IH2, IW2, OH2, OW2, 0, 0);
for (size_t ic = 1; ic < IC; ++ic) {
func_add_dst(
sptr + ic * IH2 * IW2, filter + ic * FH * FW, dptr, IH2, IW2, OH2, OW2,
0, 0);
}
if (rectify_dst) {
rep(oh, OH) { std::memcpy(dst + oh * OW, dptr + oh * OW2, sizeof(float) * OW); }
}
PostProcess<dt_float32>::run(
dst, const_cast<float*>(bias_ptr), dst, kern_param.bias_mode,
kern_param.nonlineMode, kern_param.bias_type, kern_param.dst_type, 1_z, 1_z,
OH, OW);
}
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDirectStride2::get_kimpls(
const NCBKernSizeParam& param) const {
GET_KERN;
}
#if MEGDNN_X86_WITH_MKL_DNN
static inline void mkldnn_fp32_conv_instance(
const ConvBiasImpl::NCBKernParam& param, const uint32_t ocpg,
const uint32_t icpg, const uint32_t group, const uint32_t in, const uint32_t ic,
const uint32_t oc, const uint32_t ih, const uint32_t iw, const uint32_t kh,
const uint32_t kw, const uint32_t pad_h, const uint32_t pad_w,
const uint32_t stride_h, const uint32_t stride_w, const uint32_t oh,
const uint32_t ow, std::vector<dnnl::primitive>& net,
std::vector<std::unordered_map<int, dnnl::memory>>& net_args,
dnnl::engine& eng_mkldnn) {
dnnl::memory::dims src_shape = {in, ic, ih, iw};
dnnl::memory::dims weight_shape = {oc, ic, kh, kw};
dnnl::memory::dims bias_shape = {oc};
dnnl::memory::dims dst_shape = {in, oc, oh, ow};
dnnl::memory::dims strides_shape = {stride_h, stride_w};
dnnl::memory::dims padding_shape = {pad_h, pad_w};
auto user_src_desc = dnnl::memory::desc(
{src_shape}, dnnl::memory::data_type::f32,
dnnl::memory::format_tag::nChw8c);
if (group == 1 && ic < 8) {
user_src_desc = dnnl::memory::desc(
{src_shape}, dnnl::memory::data_type::f32,
dnnl::memory::format_tag::nchw);
}
auto user_src_mem =
dnnl::memory(user_src_desc, eng_mkldnn, param.src_ptr.get_ptr());
auto weight_tag = dnnl::memory::format_tag::OIhw8i8o;
if (group > 1) {
weight_shape = {group, ocpg, icpg, kh, kw};
if (oc == group && ic == group) {
weight_tag = dnnl::memory::format_tag::Goihw8g;
} else {
weight_tag = dnnl::memory::format_tag::gOIhw8i8o;
}
} else if (group == 1 && ic < 8) {
weight_tag = dnnl::memory::format_tag::Ohwi8o;
}
auto user_weights_desc = dnnl::memory::desc(
{weight_shape}, dnnl::memory::data_type::f32, weight_tag);
auto user_weights_mem =
dnnl::memory(user_weights_desc, eng_mkldnn, param.filter_ptr.get_ptr());
auto user_bias_desc = dnnl::memory::desc();
if (param.bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
user_bias_desc = dnnl::memory::desc(
{bias_shape}, dnnl::memory::data_type::f32,
dnnl::memory::format_tag::x);
}
auto user_bias_mem =
dnnl::memory(user_bias_desc, eng_mkldnn, param.bias_ptr.get_ptr());
auto user_dst_desc = dnnl::memory::desc(
{dst_shape}, dnnl::memory::data_type::f32,
dnnl::memory::format_tag::nChw8c);
auto user_dst_mem =
dnnl::memory(user_dst_desc, eng_mkldnn, param.dst_ptr.get_ptr());
auto conv_desc = dnnl::convolution_forward::desc(
dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto,
user_src_mem.get_desc(), user_weights_mem.get_desc(),
user_bias_mem.get_desc(), user_dst_mem.get_desc(), strides_shape,
padding_shape, padding_shape);
dnnl::primitive_attr attr;
if ((param.nonlineMode == NonlineMode::RELU ||
param.nonlineMode == NonlineMode::SIGMOID) &&
(param.bias_mode == megdnn::BiasMode::NO_BIAS ||
param.bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS)) {
auto post_tag = dnnl::algorithm::eltwise_linear;
switch (param.nonlineMode) {
case NonlineMode::RELU:
post_tag = dnnl::algorithm::eltwise_relu;
break;
case NonlineMode::SIGMOID:
post_tag = dnnl::algorithm::eltwise_logistic;
break;
default:
megdnn_assert(
0, "not supported nonline mode %d\n",
static_cast<int>(param.nonlineMode));
}
dnnl::post_ops ops;
ops.append_eltwise(1.f, post_tag, 0.f, 0.f);
attr.set_post_ops(ops);
}
auto conv_prim_desc =
dnnl::convolution_forward::primitive_desc(conv_desc, attr, eng_mkldnn);
net.push_back(dnnl::convolution_forward(conv_prim_desc));
net_args.push_back(
{{DNNL_ARG_SRC, user_src_mem},
{DNNL_ARG_WEIGHTS, user_weights_mem},
{DNNL_ARG_BIAS, user_bias_mem},
{DNNL_ARG_DST, user_dst_mem}});
}
namespace {
struct NCBKernParamEqual {
bool operator()(
const fallback::ConvBiasImpl::NCBKernParam& x,
const fallback::ConvBiasImpl::NCBKernParam& y) const {
bool flag = true;
flag = flag && (x.src_ptr == y.src_ptr);
flag = flag && (x.dst_ptr == y.dst_ptr);
flag = flag && (x.filter_ptr == y.filter_ptr);
flag = flag && (x.bias_ptr == y.bias_ptr);
flag = flag && (x.isz == y.isz);
flag = flag && (x.osz == y.osz);
flag = flag && (x.src_type == y.src_type);
flag = flag && (x.dst_type == y.dst_type);
flag = flag && (x.filter_type == y.filter_type);
flag = flag && (x.bias_type == y.bias_type);
flag = flag && (x.filter_meta == y.filter_meta);
flag = flag && (x.n == y.n);
flag = flag && (x.bias_mode == y.bias_mode);
flag = flag && (x.nonlineMode == y.nonlineMode);
flag = flag && (x.bias_bs == y.bias_bs);
return flag;
};
};
struct NCBKernParamHash {
std::size_t operator()(const fallback::ConvBiasImpl::NCBKernParam& param) const {
std::size_t result = reinterpret_cast<std::size_t>(param.filter_ptr.get_ptr());
result = result ^ (reinterpret_cast<std::size_t>(param.src_ptr.get_ptr()) << 3);
result = result ^ (reinterpret_cast<std::size_t>(param.dst_ptr.get_ptr()) << 7);
result = result ^ (static_cast<std::size_t>(param.n) << 11);
return result;
};
};
}
void ConvBiasImpl::AlgoMkldnnConv::kern_mkldnn_fp32(
const NCBKernParam& param, const NCBKernIndex&) {
const NCBKernParam& key = param;
static std::unordered_map<
NCBKernParam, std::vector<dnnl::primitive>, NCBKernParamHash,
NCBKernParamEqual>
kern_net_map;
static std::unordered_map<
NCBKernParam, std::vector<std::unordered_map<int, dnnl::memory>>,
NCBKernParamHash, NCBKernParamEqual>
kern_net_arg_map;
auto x86_handle = static_cast<HandleImpl*>(inplace_cpu_handle().get());
megdnn_assert(x86_handle != nullptr, "x86 handle can not be null");
auto eng_mkldnn = x86_handle->mkldnn_engine();
auto stream_mkldnn = x86_handle->mkldnn_stream();
auto&& fm = param.filter_meta;
const uint32_t group = fm.group;
const uint32_t in = param.n;
const uint32_t ic = fm.icpg * group;
const uint32_t oc = fm.ocpg * group;
const uint32_t ih = param.isz[0];
const uint32_t iw = param.isz[1];
const uint32_t kh = fm.spatial[0];
const uint32_t kw = fm.spatial[1];
const uint32_t pad_h = fm.padding[0];
const uint32_t pad_w = fm.padding[1];
const uint32_t stride_h = fm.stride[0];
const uint32_t stride_w = fm.stride[1];
const uint32_t oh = param.osz[0];
const uint32_t ow = param.osz[1];
if (kern_net_map.find(key) == kern_net_map.end()) {
std::vector<dnnl::primitive> net;
std::vector<std::unordered_map<int, dnnl::memory>> net_args;
mkldnn_fp32_conv_instance(
param, fm.ocpg, fm.icpg, group, in, ic, oc, ih, iw, kh, kw, pad_h,
pad_w, stride_h, stride_w, oh, ow, net, net_args, eng_mkldnn);
kern_net_map[key] = net;
kern_net_arg_map[key] = net_args;
}
const auto& net = kern_net_map[key];
const auto& net_args = kern_net_arg_map[key];
for (size_t i = 0; i < net.size(); ++i) {
net.at(i).execute(stream_mkldnn, net_args.at(i));
}
stream_mkldnn.wait();
if ((param.bias_mode == megdnn::BiasMode::NO_BIAS ||
param.bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) &&
(param.nonlineMode != NonlineMode::IDENTITY &&
param.nonlineMode != NonlineMode::RELU &&
param.nonlineMode != NonlineMode::SIGMOID)) {
PostProcess<float>::run(
param.dst_ptr.get_ptr(), param.bias_ptr.get_ptr(),
param.dst_ptr.get_ptr(), megdnn::BiasMode::NO_BIAS, param.nonlineMode,
param.bias_type, param.dst_type, in, oc, oh, ow);
} else if (param.bias_mode == megdnn::BiasMode::BIAS) {
PostProcess<float>::run(
param.dst_ptr.get_ptr(), param.bias_ptr.get_ptr(),
param.dst_ptr.get_ptr(), param.bias_mode, param.nonlineMode,
param.bias_type, param.dst_type, in, oc, oh, ow);
}
}
#endif