#include "src/arm_common/pooling/algo.h"
#include "megdnn/opr_param_defs.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h"
#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h"
#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h"
#include "src/arm_common/pooling/do_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/do_pooling_3x3_nchw44.h"
#include "src/arm_common/pooling/do_pooling_4x4_nchw44.h"
#include "src/arm_common/pooling/do_pooling_5x5_nchw44.h"
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_pooling)
namespace megdnn {
namespace arm_common {
// Builds the workspace layout used by the NCHW stride-2 kernels in this file
// that call get_bundle(): one chunk of OW elements per filter row, followed by
// two chunks of (IW + 1) / 2 elements each. The bundle is created with a null
// base pointer and 16 as its third constructor argument (presumably the chunk
// alignment in bytes -- confirm against WorkspaceBundle).
// Asserts when the parameters are outside the supported family (NCHW, square
// 3x3 or 5x5 window, stride 2, input at least 2x2, MAX or 3x3 AVERAGE).
WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) {
    megdnn_assert(
            (param.src_type.category() == DTypeCategory::FLOAT ||
             param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
             param.src_type.enumv() == DTypeEnum::Quantized8Asymm ||
             param.src_type == dtype::Int8{}) &&
            param.format == param::Pooling::Format::NCHW &&
            (param.mode == param::Pooling::Mode::MAX ||
             (param.mode == param::Pooling::Mode::AVERAGE && param.filter[0] == 3)) &&
            param.filter[0] == param.filter[1] &&
            (param.filter[0] == 3 || param.filter[1] == 5) && param.stride[0] == 2 &&
            param.stride[1] == 2 && param.isz[0] >= 2 && param.isz[1] >= 2);

    const auto in_width = param.isz[1];
    const auto out_width = param.osz[1];

    SmallVector<size_t> chunk_sizes;

    // One output-row sized chunk for every row of the filter window.
    size_t filter_row = 0;
    while (filter_row < param.filter[0]) {
        chunk_sizes.push_back(out_width * param.src_type.size());
        ++filter_row;
    }

    // Two trailing chunks, each able to hold half of an input row (rounded up).
    for (int half = 0; half < 2; ++half) {
        chunk_sizes.push_back((in_width + 1) / 2 * param.src_type.size());
    }

    return WorkspaceBundle(nullptr, chunk_sizes, 16);
}
// Builds the single-chunk workspace used by the NCHW44 int8 kernels.
// The chunk is sized for a padded copy of one input plane,
// (IW + 2 * PW) * (IH + 2 * PH) * 4 bytes, when any padding is requested,
// and is empty otherwise. Asserts on a non-int8 dtype or a non-NCHW44 format.
WorkspaceBundle get_bundle_nchw44(const PoolingImpl::PoolingKernSizeParam& param) {
    megdnn_assert(
            (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
             param.src_type.enumv() == DTypeEnum::Int8) &&
            (param.format == param::Pooling::Format::NCHW44));

    auto IH = param.isz[0];
    auto IW = param.isz[1];
    auto PH = param.padding[0];
    auto PW = param.padding[1];

    const bool has_padding = (PH != 0) || (PW != 0);
    const size_t padding_size =
            has_padding ? (IW + 2 * PW) * (IH + 2 * PH) * 4 * sizeof(int8_t) : 0;

    return WorkspaceBundle(nullptr, {padding_size});
}
// Returns a pointer to the data the kernel should read and reports the
// effective plane size through IH2 / IW2.
//
// Without padding (PH == 0 and PW == 0) the source pointer is returned as-is
// and IH2 / IW2 equal IH / IW. Otherwise chunk 0 of `ws` is filled with the
// padding value (INT8_MIN for max mode, 0 otherwise), every source row of
// IW * 4 bytes is copied into the interior at row offset PH and column offset
// PW, and the workspace pointer is returned with IH2 = IH + 2 * PH and
// IW2 = IW + 2 * PW.
const int8_t* handle_padding(
        const int8_t* src, size_t IH, size_t IW, size_t& IH2, size_t& IW2, size_t PH,
        size_t PW, const WorkspaceBundle& ws, bool is_max_mode) {
    // Fast path: nothing to pad, read directly from the caller's buffer.
    if (PH == 0 && PW == 0) {
        IH2 = IH;
        IW2 = IW;
        return src;
    }

    IH2 = IH + 2 * PH;
    IW2 = IW + 2 * PW;

    int8_t* padded = static_cast<int8_t*>(ws.get(0));
    const int8_t fill = is_max_mode ? INT8_MIN : 0;
    memset(padded, fill, sizeof(int8_t) * IH2 * IW2 * 4);

    // Copy the source rows into the interior of the padded plane.
    for (size_t ih = 0; ih < IH; ++ih) {
        int8_t* dst_row = padded + (ih + PH) * IW2 * 4 + PW * 4;
        const int8_t* src_row = src + ih * IW * 4;
        std::memcpy(dst_row, src_row, sizeof(int8_t) * IW * 4);
    }
    return padded;
}
// Accepts NCHW inputs of FLOAT or QUANTIZED dtype category that are pooled with
// a square 2x2 or 3x3 window at stride 1, in MAX or AVERAGE mode.
bool PoolingImpl::AlgoFilterxModexStride1::usable(
        const PoolingKernSizeParam& param) const {
    const auto category = param.src_type.category();
    if (category != DTypeCategory::FLOAT && category != DTypeCategory::QUANTIZED) {
        return false;
    }
    if (!(param.format == Param::Format::NCHW)) {
        return false;
    }
    if (param.stride[0] != 1 || param.stride[1] != 1) {
        return false;
    }
    const auto window = param.filter[0];
    if (window != param.filter[1] || (window != 2 && window != 3)) {
        return false;
    }
    return param.mode == Mode::MAX || param.mode == Mode::AVERAGE;
}
// Executes the stride-1 pooling algorithm for a 2x2 or 3x3 window.
// The work is split into N * C tasks; each task index is decomposed into
// n = index / C and c = index % C and handles one IH*IW source plane and one
// OH*OW destination plane via do_pooling_compact.
// The kernel is selected by source dtype, then pooling mode, then filter size.
// If param.src_type matches none of the dtype branches at the bottom, nothing
// is dispatched.
void PoolingImpl::AlgoFilterxModexStride1::exec(const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic;
auto PH = param.padding[0];
auto PW = param.padding[1];
auto FH = param.filter[0];
auto src_ptr = param.src_ptr;
auto dst_ptr = param.dst_ptr;
/* DISPATCH_FUNC: opens a midout region (algorithm tag 0 plus the dtype id, the
 * pooler case numbers and the window size), builds the per-plane lambda and
 * dispatches it over N * C tasks on the handle. All state is captured by value. */
#define DISPATCH_FUNC(Pooler, NeonPooler, window, midout_type_id) \
MIDOUT_BEGIN( \
megdnn_arm_common_pooling, midout_iv(0), midout_iv(midout_type_id), \
Pooler::MIDOUT_CASE_NUM, NeonPooler::MIDOUT_CASE_NUM, window) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
src_dtype = param.src_type](size_t index, size_t) { \
size_t n = index / C; \
size_t c = index % C; \
do_pooling_compact<Pooler MEGDNN_COMMA NeonPooler MEGDNN_COMMA window>( \
static_cast<const typename Pooler::ctype*>(src_ptr.get_ptr()) + \
n * C * IH * IW + c * IH * IW, \
static_cast<typename Pooler::ctype*>(dst_ptr.get_ptr()) + \
n * C * OH * OW + c * OH * OW, \
src_dtype, IH, IW, OH, OW, PH, PW); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
} \
MIDOUT_END()
/* DISPATCH_WINDOW: picks the pooler instantiation from the filter height FH.
 * The first template argument is 4 for the 2x2 window and 9 for the 3x3 window
 * (presumably the number of elements in the window -- confirm against the
 * pooler templates). Any other FH trips the assertion. */
#define DISPATCH_WINDOW(Pooler, NeonPooler, dtype, ctype, comp_type, midout_type_id) \
switch (FH) { \
case 2: { \
using _Pooler = Pooler<4, dtype, ctype, comp_type>; \
using _NeonPooler = NeonPooler<4, dtype, ctype, comp_type>; \
DISPATCH_FUNC(_Pooler, _NeonPooler, 2, midout_type_id); \
break; \
} \
case 3: { \
using _Pooler = Pooler<9, dtype, ctype, comp_type>; \
using _NeonPooler = NeonPooler<9, dtype, ctype, comp_type>; \
DISPATCH_FUNC(_Pooler, _NeonPooler, 3, midout_type_id); \
break; \
} \
default: \
megdnn_assert(0, "unsupport pooling filter size"); \
break; \
}
/* DISPATCH_MODE: maps the pooling mode to the scalar/NEON pooler pair
 * (MaxPooler/NeonMaxPooler or MeanInPooler/NeonMeanPooler). */
#define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id) \
switch (param.mode) { \
case Mode::MAX: \
DISPATCH_WINDOW( \
MaxPooler, NeonMaxPooler, dtype, ctype, comp_type, \
midout_type_id); \
break; \
case Mode::AVERAGE: \
DISPATCH_WINDOW( \
MeanInPooler, NeonMeanPooler, dtype, ctype, comp_type, \
midout_type_id); \
break; \
default: \
megdnn_assert(0, "unsupport pooling mode"); \
break; \
}
// Dtype dispatch; the last macro argument is the midout dtype id (0..3).
if (param.src_type == dtype::Float32{}) {
DISPATCH_MODE(dt_float32, float, float, 0);
} else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
DISPATCH_MODE(dt_qint8, int8_t, float, 1);
} else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
DISPATCH_MODE(dt_quint8, uint8_t, float, 2);
// The fp16 branch only exists when the target has fp16 vector arithmetic.
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} else if (param.src_type == dtype::Float16{}) {
DISPATCH_MODE(dt_float16, __fp16, __fp16, 3);
#endif
}
#undef DISPATCH_FUNC
#undef DISPATCH_WINDOW
#undef DISPATCH_MODE
}
// Accepts NCHW inputs of FLOAT or QUANTIZED dtype category pooled with a 2x2
// window at stride 2 (both axes), in MAX or AVERAGE mode.
bool PoolingImpl::AlgoFilter2ModexStride2::usable(
        const PoolingKernSizeParam& param) const {
    const auto category = param.src_type.category();
    const bool dtype_ok = category == DTypeCategory::FLOAT ||
                          category == DTypeCategory::QUANTIZED;
    const bool layout_ok = param.format == Param::Format::NCHW;
    const bool window_ok = param.filter[0] == param.filter[1] && param.filter[0] == 2;
    const bool stride_ok = param.stride[0] == param.stride[1] && param.stride[0] == 2;
    const bool mode_ok = param.mode == Mode::MAX || param.mode == Mode::AVERAGE;
    return dtype_ok && layout_ok && window_ok && stride_ok && mode_ok;
}
// Executes 2x2 stride-2 pooling (MAX or AVERAGE).
// The work is split into N * C tasks; each task index is decomposed into
// n = index / C and c = index % C and handles one IH*IW source plane and one
// OH*OW destination plane via do_pooling_2x2.
// If param.src_type matches none of the dtype branches at the bottom, nothing
// is dispatched.
void PoolingImpl::AlgoFilter2ModexStride2::exec(const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic;
auto PH = param.padding[0];
auto PW = param.padding[1];
auto src_ptr = param.src_ptr;
auto dst_ptr = param.dst_ptr;
/* DISPATCH_FUNC: opens a midout region (algorithm tag 1 plus the dtype id and
 * the pooler case number), builds the per-plane lambda and dispatches it over
 * N * C tasks on the handle. All state is captured by value. */
#define DISPATCH_FUNC(Pooler, mode, midout_type_id) \
MIDOUT_BEGIN( \
megdnn_arm_common_pooling, midout_iv(1), midout_iv(midout_type_id), \
Pooler::MIDOUT_CASE_NUM) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
src_dtype = param.src_type](size_t index, size_t) { \
size_t n = index / C; \
size_t c = index % C; \
do_pooling_2x2<Pooler MEGDNN_COMMA mode>( \
static_cast<const typename Pooler::ctype*>(src_ptr.get_ptr()) + \
n * C * IH * IW + c * IH * IW, \
static_cast<typename Pooler::ctype*>(dst_ptr.get_ptr()) + \
n * C * OH * OW + c * OH * OW, \
src_dtype, IH, IW, OH, OW, PH, PW); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
} \
MIDOUT_END()
/* DISPATCH_MODE: maps the pooling mode to MaxPooler or MeanInPooler, both
 * instantiated with 4 as the first template argument, and forwards the mode
 * as a template argument of do_pooling_2x2. */
#define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id) \
switch (param.mode) { \
case Mode::MAX: { \
using _Pooler = MaxPooler<4, dtype, ctype, comp_type>; \
DISPATCH_FUNC(_Pooler, Mode::MAX, midout_type_id); \
break; \
} \
case Mode::AVERAGE: { \
using _Pooler = MeanInPooler<4, dtype, ctype, comp_type>; \
DISPATCH_FUNC(_Pooler, Mode::AVERAGE, midout_type_id); \
break; \
} \
default: \
megdnn_assert(0, "unsupport pooling mode"); \
break; \
}
// Dtype dispatch; the last macro argument is the midout dtype id (0..3).
if (param.src_type == dtype::Float32{}) {
DISPATCH_MODE(dt_float32, float, float, 0);
} else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
DISPATCH_MODE(dt_qint8, int8_t, float, 1);
} else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
DISPATCH_MODE(dt_quint8, uint8_t, float, 2);
// The fp16 branch only exists when the target has fp16 vector arithmetic.
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} else if (param.src_type == dtype::Float16{}) {
DISPATCH_MODE(dt_float16, __fp16, __fp16, 3);
#endif
}
#undef DISPATCH_FUNC
// NOTE(review): DISPATCH_PAD is not defined anywhere in this function; this
// #undef looks like a leftover -- confirm no included header relies on it.
#undef DISPATCH_PAD
#undef DISPATCH_MODE
}
// Accepts NCHW MAX pooling of FLOAT or QUANTIZED category inputs with a 3x3
// window at stride 2, provided the input plane is at least 2x2.
bool PoolingImpl::AlgoFilter3MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    const auto category = param.src_type.category();
    if (category != DTypeCategory::FLOAT && category != DTypeCategory::QUANTIZED) {
        return false;
    }
    if (!(param.format == Param::Format::NCHW) || !(param.mode == Mode::MAX)) {
        return false;
    }
    if (param.filter[0] != 3 || param.filter[1] != 3) {
        return false;
    }
    if (param.stride[0] != 2 || param.stride[1] != 2) {
        return false;
    }
    return param.isz[0] >= 2 && param.isz[1] >= 2;
}
// Executes 3x3 stride-2 max pooling.
// The work is split into N * C tasks; each task index is decomposed into
// n = index / C and c = index % C and handles one IH*IW source plane and one
// OH*OW destination plane. Each task uses a slice of the shared workspace
// selected by its thread_id.
// If param.src_type matches none of the dtype branches at the bottom, nothing
// is dispatched.
void PoolingImpl::AlgoFilter3MaxStride2::exec(const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic;
auto PH = param.padding[0];
auto PW = param.padding[1];
auto src_ptr = param.src_ptr;
auto dst_ptr = param.dst_ptr;
/* DISPATCH_FUNC(type, func, midout_type_id):
 * - `type` is the element type the raw src/dst pointers are cast to;
 * - `func` is pasted into the kernel name do_max_pooling_3x3_s2x2_<func>_NEON;
 * - the bundle layout comes from get_bundle(param); inside the lambda a copy
 *   of it is rebased to workspace_ptr + total_size_in_bytes() * thread_id so
 *   that each thread_id gets its own slice. */
#define DISPATCH_FUNC(type, func, midout_type_id) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), midout_iv(midout_type_id)) { \
WorkspaceBundle wbundle = get_bundle(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \
auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \
size_t c = index % C; \
do_max_pooling_3x3_s2x2_##func##_NEON( \
static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW + \
c * IH * IW, \
static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW + \
c * OH * OW, \
IH, IW, OH, OW, PH, PW, ws); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
} \
MIDOUT_END();
// Dtype dispatch; the last macro argument is the midout dtype id (0..3).
if (param.src_type == dtype::Float32{}) {
DISPATCH_FUNC(float, float, 0);
} else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
DISPATCH_FUNC(int8_t, int8, 1);
} else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
DISPATCH_FUNC(uint8_t, uint8, 2);
// The fp16 branch only exists when the target has fp16 vector arithmetic.
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} else if (param.src_type == dtype::Float16{}) {
DISPATCH_FUNC(__fp16, float16, 3);
#endif
}
#undef DISPATCH_FUNC
}
// Accepts NCHW AVERAGE pooling of FLOAT category inputs with a 3x3 window at
// stride 2, provided the input plane is at least 2x2.
bool PoolingImpl::AlgoFilter3AverageStride2::usable(
        const PoolingKernSizeParam& param) const {
    const bool dtype_ok = param.src_type.category() == DTypeCategory::FLOAT;
    const bool layout_ok = param.format == Param::Format::NCHW;
    const bool mode_ok = param.mode == Mode::AVERAGE;
    const bool window_ok = param.filter[0] == 3 && param.filter[1] == 3;
    const bool stride_ok = param.stride[0] == 2 && param.stride[1] == 2;
    const bool size_ok = param.isz[0] >= 2 && param.isz[1] >= 2;
    return dtype_ok && layout_ok && mode_ok && window_ok && stride_ok && size_ok;
}
// Executes 3x3 stride-2 average pooling for float inputs.
// The work is split into N * C tasks; each task index is decomposed into
// n = index / C and c = index % C and handles one IH*IW source plane and one
// OH*OW destination plane. Each task uses a slice of the shared workspace
// selected by its thread_id.
// If param.src_type is neither Float32 nor (when available) Float16, nothing
// is dispatched.
void PoolingImpl::AlgoFilter3AverageStride2::exec(const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic;
auto PH = param.padding[0];
auto PW = param.padding[1];
auto src_ptr = param.src_ptr;
auto dst_ptr = param.dst_ptr;
/* DISPATCH_FUNC(type, MEGDNN_SIMD_WIDTH, midout_type_id):
 * - `type` is the element type the raw src/dst pointers are cast to;
 * - MEGDNN_SIMD_WIDTH is a macro parameter here and is forwarded as the last
 *   argument of do_average_pooling_3x3_s2x2_NEON (4 for float32, 8 for fp16);
 * - the bundle layout comes from get_bundle(param); inside the lambda a copy
 *   of it is rebased to workspace_ptr + total_size_in_bytes() * thread_id. */
#define DISPATCH_FUNC(type, MEGDNN_SIMD_WIDTH, midout_type_id) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(3), midout_iv(midout_type_id)) { \
WorkspaceBundle wbundle = get_bundle(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \
auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \
size_t c = index % C; \
do_average_pooling_3x3_s2x2_NEON( \
static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW + \
c * IH * IW, \
static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW + \
c * OH * OW, \
IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
} \
MIDOUT_END();
// Dtype dispatch: (element type, SIMD width, midout dtype id).
if (param.src_type == dtype::Float32{}) {
DISPATCH_FUNC(dt_float32, 4, 0);
// The fp16 branch only exists when the target has fp16 vector arithmetic.
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} else if (param.src_type == dtype::Float16{}) {
DISPATCH_FUNC(__fp16, 8, 1);
#endif
}
#undef DISPATCH_FUNC
}
// Accepts NCHW MAX pooling of FLOAT or QUANTIZED category inputs with a 4x4
// window at stride 2, provided the output plane is at least 2x2.
bool PoolingImpl::AlgoFilter4MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    const auto category = param.src_type.category();
    const bool dtype_ok = category == DTypeCategory::FLOAT ||
                          category == DTypeCategory::QUANTIZED;
    if (!dtype_ok) {
        return false;
    }
    if (!(param.format == Param::Format::NCHW) || !(param.mode == Mode::MAX)) {
        return false;
    }
    if (param.filter[0] != 4 || param.filter[1] != 4) {
        return false;
    }
    if (param.stride[0] != 2 || param.stride[1] != 2) {
        return false;
    }
    return param.osz[0] >= 2 && param.osz[1] >= 2;
}
// Executes 4x4 stride-2 max pooling.
// The work is split into N * C tasks; each task index is decomposed into
// n = index / C and c = index % C and handles one IH*IW source plane and one
// OH*OW destination plane. No workspace is used by this algorithm.
// If param.src_type matches none of the dtype branches at the bottom, nothing
// is dispatched.
void PoolingImpl::AlgoFilter4MaxStride2::exec(const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic;
auto PH = param.padding[0];
auto PW = param.padding[1];
auto src_ptr = param.src_ptr;
auto dst_ptr = param.dst_ptr;
/* DISPATCH_FUNC(type, func, midout_type_id):
 * - `type` is the element type the raw src/dst pointers are cast to;
 * - `func` is pasted into the kernel name do_max_pooling_w4x4_s2x2_<func>_NEON;
 * - the source dtype is captured by value and forwarded to the kernel. */
#define DISPATCH_FUNC(type, func, midout_type_id) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(4), midout_iv(midout_type_id)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
src_dtype = param.src_type](size_t index, size_t) { \
size_t n = index / C; \
size_t c = index % C; \
do_max_pooling_w4x4_s2x2_##func##_NEON( \
static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW + \
c * IH * IW, \
static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW + \
c * OH * OW, \
src_dtype, IH, IW, OH, OW, PH, PW); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
} \
MIDOUT_END();
// Dtype dispatch; the last macro argument is the midout dtype id (0..3).
if (param.src_type == dtype::Float32{}) {
DISPATCH_FUNC(float, float, 0);
} else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
DISPATCH_FUNC(int8_t, int8, 1);
} else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
DISPATCH_FUNC(uint8_t, uint8, 2);
// The fp16 branch only exists when the target has fp16 vector arithmetic.
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} else if (param.src_type == dtype::Float16{}) {
DISPATCH_FUNC(__fp16, float16, 3);
#endif
}
#undef DISPATCH_FUNC
}
bool PoolingImpl::AlgoFilter5MaxStride2::usable(
const PoolingKernSizeParam& param) const {
auto SH = param.stride[0];
auto SW = param.stride[1];
auto FH = param.filter[0];
auto FW = param.filter[1];
auto OH = param.osz[0], OW = param.osz[1];
bool avaible = (param.src_type.category() == DTypeCategory::FLOAT ||
param.src_type.category() == DTypeCategory::QUANTIZED) &&
param.format == Param::Format::NCHW && param.mode == Mode::MAX &&
FH == 5 && FW == 5 && SH == 2 && SW == 2 && OH >= 2 && OW >= 2;
return avaible;
}
// Executes 5x5 stride-2 max pooling.
// The work is split into N * C tasks; each task index is decomposed into
// n = index / C and c = index % C and handles one IH*IW source plane and one
// OH*OW destination plane. Each task uses a slice of the shared workspace
// selected by its thread_id.
// If param.src_type matches none of the dtype branches at the bottom, nothing
// is dispatched.
void PoolingImpl::AlgoFilter5MaxStride2::exec(const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic;
auto PH = param.padding[0];
auto PW = param.padding[1];
auto src_ptr = param.src_ptr;
auto dst_ptr = param.dst_ptr;
/* DISPATCH_FUNC(dtype, type, midout_type_id, MEGDNN_SIMD_WIDTH):
 * - `dtype` is the template argument of do_max_pooling_w5x5_s2x2_NEON;
 * - `type` is the element type the raw src/dst pointers are cast to;
 * - MEGDNN_SIMD_WIDTH is a macro parameter here and is forwarded as the last
 *   kernel argument;
 * - the bundle layout comes from get_bundle(param), which asserts on its
 *   supported parameter range; inside the lambda a copy of the bundle is
 *   rebased to workspace_ptr + total_size_in_bytes() * thread_id. */
#define DISPATCH_FUNC(dtype, type, midout_type_id, MEGDNN_SIMD_WIDTH) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(5), midout_iv(midout_type_id)) { \
WorkspaceBundle wbundle = get_bundle(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \
auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \
size_t c = index % C; \
do_max_pooling_w5x5_s2x2_NEON<dtype>( \
static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW + \
c * IH * IW, \
static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW + \
c * OH * OW, \
IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
} \
MIDOUT_END();
// Dtype dispatch: (kernel dtype, element type, midout dtype id, SIMD width).
if (param.src_type == dtype::Float32{}) {
DISPATCH_FUNC(dt_float32, float, 0, 4);
} else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
DISPATCH_FUNC(dt_int8, int8_t, 1, 16);
} else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
DISPATCH_FUNC(dt_uint8, uint8_t, 2, 16);
// The fp16 branch only exists when the target has fp16 vector arithmetic.
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} else if (param.src_type == dtype::Float16{}) {
DISPATCH_FUNC(dt_float16, __fp16, 3, 8);
#endif
}
#undef DISPATCH_FUNC
}
// Accepts NCHW MAX pooling of plain Int8 inputs with a 2x2 window at stride 2
// and no padding.
bool PoolingImpl::AlgoInt8Filter2MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    if (!(param.src_type == dtype::Int8())) {
        return false;
    }
    if (!(param.format == Param::Format::NCHW) || !(param.mode == Mode::MAX)) {
        return false;
    }
    const bool stride_ok = param.stride[0] == 2 && param.stride[1] == 2;
    const bool unpadded = param.padding[0] == 0 && param.padding[1] == 0;
    const bool window_ok = param.filter[0] == 2 && param.filter[1] == 2;
    return stride_ok && unpadded && window_ok;
}
// Executes 2x2 stride-2 max pooling on int8 data.
// One task is dispatched per (batch, channel) pair; each task calls
// pooling_max_w2x2_s2x2 on a single IH*IW source plane and OH*OW destination
// plane, passing 1 and 1 as the kernel's third and fourth arguments.
void PoolingImpl::AlgoInt8Filter2MaxStride2::exec(const PoolingKernParam& param) const {
    auto in_h = param.isz[0], in_w = param.isz[1];
    auto out_h = param.osz[0], out_w = param.osz[1];
    auto batch = param.n, channels = param.ic;
    auto src_base = param.src<dt_int8>();
    auto dst_base = param.dst<dt_int8>();

    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(6)) {
        auto kern = [channels, in_h, in_w, out_h, out_w, src_base, dst_base](
                            size_t index, size_t) {
            // Recover the (batch, channel) coordinates of this task.
            size_t n = index / channels;
            size_t c = index % channels;
            auto plane_src = src_base + n * channels * in_h * in_w + c * in_h * in_w;
            auto plane_dst =
                    dst_base + n * channels * out_h * out_w + c * out_h * out_w;
            pooling_max_w2x2_s2x2(
                    plane_src, plane_dst, 1, 1, in_h, in_w, out_h, out_w);
        };
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(
                static_cast<::megdnn::naive::HandleImpl*>(param.handle),
                batch * channels, kern);
    }
    MIDOUT_END();
}
// Accepts NCHW MAX pooling of plain Int8 inputs with a 3x3 window at stride 2,
// provided the input plane is at least 2x2.
bool PoolingImpl::AlgoInt8Filter3MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    const bool dtype_ok = param.src_type == dtype::Int8();
    const bool layout_ok = param.format == Param::Format::NCHW;
    const bool mode_ok = param.mode == Mode::MAX;
    const bool window_ok = param.filter[0] == 3 && param.filter[1] == 3;
    const bool stride_ok = param.stride[0] == 2 && param.stride[1] == 2;
    const bool size_ok = param.isz[0] >= 2 && param.isz[1] >= 2;
    return dtype_ok && layout_ok && mode_ok && window_ok && stride_ok && size_ok;
}
// Executes 3x3 stride-2 max pooling on int8 data.
// One task is dispatched per (batch, channel) pair; each task calls
// do_max_pooling_3x3_s2x2_int8_NEON on a single IH*IW source plane and OH*OW
// destination plane. The workspace layout comes from get_bundle(param) and is
// rebased per task to a slice selected by the task's thread_id.
void PoolingImpl::AlgoInt8Filter3MaxStride2::exec(const PoolingKernParam& param) const {
    auto in_h = param.isz[0], in_w = param.isz[1];
    auto out_h = param.osz[0], out_w = param.osz[1];
    auto batch = param.n, channels = param.ic;
    auto pad_h = param.padding[0];
    auto pad_w = param.padding[1];
    auto src_base = param.src<dt_int8>();
    auto dst_base = param.dst<dt_int8>();

    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(7)) {
        WorkspaceBundle layout = get_bundle(param);
        auto kern = [channels, in_h, in_w, out_h, out_w, pad_h, pad_w, src_base,
                     dst_base, layout = layout,
                     ws_base = param.workspace<dt_byte>()](
                            size_t index, size_t thread_id) {
            // Private copy of the layout, pointed at this thread's slice.
            auto thread_ws = layout;
            thread_ws.set(ws_base + thread_id * thread_ws.total_size_in_bytes());

            size_t n = index / channels;
            size_t c = index % channels;
            auto plane_src = src_base + n * channels * in_h * in_w + c * in_h * in_w;
            auto plane_dst =
                    dst_base + n * channels * out_h * out_w + c * out_h * out_w;
            do_max_pooling_3x3_s2x2_int8_NEON(
                    plane_src, plane_dst, in_h, in_w, out_h, out_w, pad_h, pad_w,
                    thread_ws);
        };
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(
                static_cast<::megdnn::naive::HandleImpl*>(param.handle),
                batch * channels, kern);
    }
    MIDOUT_END();
}
// Accepts NCHW44 pooling of QuantizedS8 or Int8 inputs with a 3x3 window and
// equal strides of 1 or 2, in MAX or AVERAGE mode -- except that AVERAGE is
// rejected for plain Int8.
bool PoolingImpl::AlgoFilter3ModexStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    const bool is_qint8 = param.src_type.enumv() == DTypeEnum::QuantizedS8;
    const bool is_int8 = param.src_type.enumv() == DTypeEnum::Int8;
    if (!is_qint8 && !is_int8) {
        return false;
    }
    if (!(param.format == Param::Format::NCHW44)) {
        return false;
    }
    const bool is_max = param.mode == Mode::MAX;
    const bool is_avg = param.mode == Mode::AVERAGE;
    if (!is_max && !is_avg) {
        return false;
    }
    if (param.filter[0] != 3 || param.filter[1] != 3) {
        return false;
    }
    const auto stride_h = param.stride[0];
    const auto stride_w = param.stride[1];
    if (stride_w != stride_h || !(stride_h == 1 || stride_w == 2)) {
        return false;
    }
    // Averaging is only provided for the quantized dtype.
    return !(is_int8 && is_avg);
}
// Executes 3x3 pooling (MAX or AVERAGE, stride 1 or 2) on NCHW44 int8 data.
// The work is split into N * C tasks; each task index is decomposed into
// n = index / C and c = index % C and handles IH * IW * 4 source elements and
// OH * OW * 4 destination elements. Each task uses a slice of the shared
// workspace selected by its thread_id.
// Throws via megdnn_throw on an unsupported mode or stride.
void PoolingImpl::AlgoFilter3ModexStridexNCHW44::exec(
const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic;
auto PH = param.padding[0];
auto PW = param.padding[1];
// Named SW but read from stride[0]; usable() requires both strides to be equal.
auto SW = param.stride[0];
auto src_ptr = param.src_ptr;
auto dst_ptr = param.dst_ptr;
/* DISPATCH_FUNC(type, func, i, mode):
 * - pastes the kernel name do_<mode>_pooling_3x3_stride<i>_<func>_nchw44_NEON;
 * - the second midout tag is built from the stringified `type` and `i`
 *   followed by the _hash suffix;
 * - the bundle layout comes from get_bundle_nchw44(param); inside the lambda a
 *   copy of it is rebased to workspace_ptr + total_size_in_bytes() * thread_id. */
#define DISPATCH_FUNC(type, func, i, mode) \
MIDOUT_BEGIN( \
megdnn_arm_common_pooling, midout_iv(8), midout_iv(#type #i##_hash)) { \
WorkspaceBundle wbundle = get_bundle_nchw44(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \
auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \
size_t c = index % C; \
do_##mode##_pooling_3x3_stride##i##_##func##_nchw44_NEON( \
static_cast<const type*>(src_ptr.get_ptr()) + \
n * C * IH * IW * 4 + c * IH * IW * 4, \
static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW, ws); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
} \
MIDOUT_END();
/* DISPATCH_MODE: maps the pooling mode to the `max` / `avg` kernel-name token;
 * any other mode throws. */
#define DISPATCH_MODE(type, func, stride) \
switch (param.mode) { \
case Mode::MAX: { \
DISPATCH_FUNC(type, func, stride, max); \
break; \
} \
case Mode::AVERAGE: { \
DISPATCH_FUNC(type, func, stride, avg); \
break; \
} \
default: \
megdnn_throw( \
ssprintf( \
"Unsupport pooling mode %d", static_cast<int>(param.mode)) \
.c_str()); \
}
/* DISPATCH_STRIDE: maps the runtime stride to the literal 1 or 2 that is
 * pasted into the kernel name; any other stride throws. */
#define DISPATCH_STRIDE(type, func) \
switch (SW) { \
case 1: { \
DISPATCH_MODE(type, func, 1); \
break; \
} \
case 2: { \
DISPATCH_MODE(type, func, 2); \
break; \
} \
default: \
megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \
}
// Only the int8_t element type is dispatched, with no dtype check in exec.
DISPATCH_STRIDE(int8_t, int8);
#undef DISPATCH_STRIDE
#undef DISPATCH_MODE
#undef DISPATCH_FUNC
}
// Accepts NCHW44 pooling of QuantizedS8 or Int8 inputs with a 2x2 window and
// equal strides of 1 or 2, in MAX or AVERAGE mode -- except that AVERAGE is
// rejected for plain Int8.
bool PoolingImpl::AlgoFilter2ModexStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    const bool is_qint8 = param.src_type.enumv() == DTypeEnum::QuantizedS8;
    const bool is_int8 = param.src_type.enumv() == DTypeEnum::Int8;
    const bool is_max = param.mode == Mode::MAX;
    const bool is_avg = param.mode == Mode::AVERAGE;

    const bool dtype_ok = is_qint8 || is_int8;
    const bool layout_ok = param.format == Param::Format::NCHW44;
    const bool mode_ok = is_max || is_avg;
    const bool window_ok = param.filter[0] == 2 && param.filter[1] == 2;
    const bool stride_ok = param.stride[0] == param.stride[1] &&
                           (param.stride[1] == 1 || param.stride[1] == 2);
    // Averaging is only provided for the quantized dtype.
    const bool combo_ok = !(is_int8 && is_avg);

    return dtype_ok && layout_ok && mode_ok && window_ok && stride_ok && combo_ok;
}
// Executes 2x2 pooling (MAX or AVERAGE, stride 1 or 2) on NCHW44 int8 data.
// The work is split into N * C tasks; each task index is decomposed into
// n = index / C and c = index % C and handles IH * IW * 4 source elements and
// OH * OW * 4 destination elements. Each task uses a slice of the shared
// workspace selected by its thread_id.
// Throws via megdnn_throw on an unsupported mode or stride.
void PoolingImpl::AlgoFilter2ModexStridexNCHW44::exec(
const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic;
auto PH = param.padding[0];
auto PW = param.padding[1];
// Named SW but read from stride[0]; usable() requires both strides to be equal.
auto SW = param.stride[0];
auto src_ptr = param.src_ptr;
auto dst_ptr = param.dst_ptr;
/* DISPATCH_FUNC(type, func, i, mode):
 * - pastes the kernel name do_<mode>_pooling_2x2_stride<i>_<func>_nchw44_NEON;
 * - the second midout tag is built from the stringified `func` and `i`
 *   followed by the _hash suffix;
 * - the bundle layout comes from get_bundle_nchw44(param); inside the lambda a
 *   copy of it is rebased to workspace_ptr + total_size_in_bytes() * thread_id. */
#define DISPATCH_FUNC(type, func, i, mode) \
MIDOUT_BEGIN( \
megdnn_arm_common_pooling, midout_iv(9), midout_iv(#func #i##_hash)) { \
WorkspaceBundle wbundle = get_bundle_nchw44(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \
auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \
size_t c = index % C; \
do_##mode##_pooling_2x2_stride##i##_##func##_nchw44_NEON( \
static_cast<const type*>(src_ptr.get_ptr()) + \
n * C * IH * IW * 4 + c * IH * IW * 4, \
static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW, ws); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
} \
MIDOUT_END();
/* DISPATCH_MODE: maps the pooling mode to the `max` / `avg` kernel-name token;
 * any other mode throws. */
#define DISPATCH_MODE(type, func, stride) \
switch (param.mode) { \
case Mode::MAX: { \
DISPATCH_FUNC(type, func, stride, max); \
break; \
} \
case Mode::AVERAGE: { \
DISPATCH_FUNC(type, func, stride, avg); \
break; \
} \
default: \
megdnn_throw( \
ssprintf( \
"Unsupport pooling mode %d", static_cast<int>(param.mode)) \
.c_str()); \
}
/* DISPATCH_STRIDE: maps the runtime stride to the literal 1 or 2 that is
 * pasted into the kernel name; any other stride throws. */
#define DISPATCH_STRIDE(type, func) \
switch (SW) { \
case 1: { \
DISPATCH_MODE(type, func, 1); \
break; \
} \
case 2: { \
DISPATCH_MODE(type, func, 2); \
break; \
} \
default: \
megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \
}
// Only the int8_t element type is dispatched, with no dtype check in exec.
DISPATCH_STRIDE(int8_t, int8);
#undef DISPATCH_STRIDE
#undef DISPATCH_MODE
#undef DISPATCH_FUNC
}
// Accepts NCHW44 pooling of QuantizedS8 or Int8 inputs with a 4x4 window and
// equal strides of 1 or 2, in MAX or AVERAGE mode -- except that AVERAGE is
// rejected for plain Int8.
bool PoolingImpl::AlgoFilter4ModexStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    const bool is_qint8 = param.src_type.enumv() == DTypeEnum::QuantizedS8;
    const bool is_int8 = param.src_type.enumv() == DTypeEnum::Int8;
    if (!is_qint8 && !is_int8) {
        return false;
    }
    if (!(param.format == Param::Format::NCHW44)) {
        return false;
    }
    const bool is_max = param.mode == Mode::MAX;
    const bool is_avg = param.mode == Mode::AVERAGE;
    if (!is_max && !is_avg) {
        return false;
    }
    if (param.filter[0] != 4 || param.filter[1] != 4) {
        return false;
    }
    const auto stride_h = param.stride[0];
    const auto stride_w = param.stride[1];
    if (stride_h != stride_w || (stride_w != 1 && stride_w != 2)) {
        return false;
    }
    // Averaging is only provided for the quantized dtype.
    return !(is_int8 && is_avg);
}
// Executes 4x4 pooling (MAX or AVERAGE, stride 1 or 2) on NCHW44 int8 data.
// The work is split into N * C tasks; each task index is decomposed into
// n = index / C and c = index % C and handles IH * IW * 4 source elements and
// OH * OW * 4 destination elements. Each task uses a slice of the shared
// workspace selected by its thread_id.
// Throws via megdnn_throw on an unsupported mode or stride.
void PoolingImpl::AlgoFilter4ModexStridexNCHW44::exec(
const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic;
auto PH = param.padding[0];
auto PW = param.padding[1];
// Named SW but read from stride[0]; usable() requires both strides to be equal.
auto SW = param.stride[0];
auto src_ptr = param.src_ptr;
auto dst_ptr = param.dst_ptr;
/* DISPATCH_FUNC(type, func, i, mode):
 * - pastes the kernel name do_<mode>_pooling_4x4_stride<i>_<func>_nchw44_NEON;
 * - the second midout tag is built from the stringified `func` and `i`
 *   followed by the _hash suffix;
 * - the bundle layout comes from get_bundle_nchw44(param); inside the lambda a
 *   copy of it is rebased to workspace_ptr + total_size_in_bytes() * thread_id. */
#define DISPATCH_FUNC(type, func, i, mode) \
MIDOUT_BEGIN( \
megdnn_arm_common_pooling, midout_iv(10), midout_iv(#func #i##_hash)) { \
WorkspaceBundle wbundle = get_bundle_nchw44(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \
auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \
size_t c = index % C; \
do_##mode##_pooling_4x4_stride##i##_##func##_nchw44_NEON( \
static_cast<const type*>(src_ptr.get_ptr()) + \
n * C * IH * IW * 4 + c * IH * IW * 4, \
static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW, ws); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
} \
MIDOUT_END();
/* DISPATCH_MODE: maps the pooling mode to the `max` / `avg` kernel-name token;
 * any other mode throws. */
#define DISPATCH_MODE(type, func, stride) \
switch (param.mode) { \
case Mode::MAX: { \
DISPATCH_FUNC(type, func, stride, max); \
break; \
} \
case Mode::AVERAGE: { \
DISPATCH_FUNC(type, func, stride, avg); \
break; \
} \
default: \
megdnn_throw( \
ssprintf( \
"Unsupport pooling mode %d", static_cast<int>(param.mode)) \
.c_str()); \
}
/* DISPATCH_STRIDE: maps the runtime stride to the literal 1 or 2 that is
 * pasted into the kernel name; any other stride throws. */
#define DISPATCH_STRIDE(type, func) \
switch (SW) { \
case 1: { \
DISPATCH_MODE(type, func, 1); \
break; \
} \
case 2: { \
DISPATCH_MODE(type, func, 2); \
break; \
} \
default: \
megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \
}
// Only the int8_t element type is dispatched, with no dtype check in exec.
DISPATCH_STRIDE(int8_t, int8);
#undef DISPATCH_STRIDE
#undef DISPATCH_MODE
#undef DISPATCH_FUNC
}
// Accepts NCHW44 pooling of QuantizedS8 or Int8 inputs with a 5x5 window and
// equal strides of 1 or 2, in MAX or AVERAGE mode -- except that AVERAGE is
// rejected for plain Int8.
bool PoolingImpl::AlgoFilter5ModexStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    const bool is_qint8 = param.src_type.enumv() == DTypeEnum::QuantizedS8;
    const bool is_int8 = param.src_type.enumv() == DTypeEnum::Int8;
    const bool is_max = param.mode == Mode::MAX;
    const bool is_avg = param.mode == Mode::AVERAGE;

    const bool dtype_ok = is_qint8 || is_int8;
    const bool layout_ok = param.format == Param::Format::NCHW44;
    const bool mode_ok = is_max || is_avg;
    const bool window_ok = param.filter[0] == 5 && param.filter[1] == 5;
    const bool stride_ok = param.stride[0] == param.stride[1] &&
                           (param.stride[1] == 1 || param.stride[1] == 2);
    // Averaging is only provided for the quantized dtype.
    const bool combo_ok = !(is_int8 && is_avg);

    return dtype_ok && layout_ok && mode_ok && window_ok && stride_ok && combo_ok;
}
// Executes 5x5 pooling (MAX or AVERAGE, stride 1 or 2) on NCHW44 int8 data.
// The work is split into N * C tasks; each task index is decomposed into
// n = index / C and c = index % C and handles IH * IW * 4 source elements and
// OH * OW * 4 destination elements. Each task uses a slice of the shared
// workspace selected by its thread_id.
// Throws via megdnn_throw on an unsupported mode or stride.
void PoolingImpl::AlgoFilter5ModexStridexNCHW44::exec(
const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic;
auto PH = param.padding[0];
auto PW = param.padding[1];
// Named SW but read from stride[0]; usable() requires both strides to be equal.
auto SW = param.stride[0];
auto src_ptr = param.src_ptr;
auto dst_ptr = param.dst_ptr;
/* DISPATCH_FUNC(type, func, i, mode):
 * - pastes the kernel name do_<mode>_pooling_5x5_stride<i>_<func>_nchw44_NEON;
 * - the second midout tag is built from the stringified `func` and `i`
 *   followed by the _hash suffix;
 * - the bundle layout comes from get_bundle_nchw44(param); inside the lambda a
 *   copy of it is rebased to workspace_ptr + total_size_in_bytes() * thread_id. */
#define DISPATCH_FUNC(type, func, i, mode) \
MIDOUT_BEGIN( \
megdnn_arm_common_pooling, midout_iv(11), midout_iv(#func #i##_hash)) { \
WorkspaceBundle wbundle = get_bundle_nchw44(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \
auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \
size_t c = index % C; \
do_##mode##_pooling_5x5_stride##i##_##func##_nchw44_NEON( \
static_cast<const type*>(src_ptr.get_ptr()) + \
n * C * IH * IW * 4 + c * IH * IW * 4, \
static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW, ws); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
} \
MIDOUT_END();
/* DISPATCH_MODE: maps the pooling mode to the `max` / `avg` kernel-name token;
 * any other mode throws. */
#define DISPATCH_MODE(type, func, stride) \
switch (param.mode) { \
case Mode::MAX: { \
DISPATCH_FUNC(type, func, stride, max); \
break; \
} \
case Mode::AVERAGE: { \
DISPATCH_FUNC(type, func, stride, avg); \
break; \
} \
default: \
megdnn_throw( \
ssprintf( \
"Unsupport pooling mode %d", static_cast<int>(param.mode)) \
.c_str()); \
}
/* DISPATCH_STRIDE: maps the runtime stride to the literal 1 or 2 that is
 * pasted into the kernel name; any other stride throws. */
#define DISPATCH_STRIDE(type, func) \
switch (SW) { \
case 1: { \
DISPATCH_MODE(type, func, 1); \
break; \
} \
case 2: { \
DISPATCH_MODE(type, func, 2); \
break; \
} \
default: \
megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \
}
// Only the int8_t element type is dispatched, with no dtype check in exec.
DISPATCH_STRIDE(int8_t, int8);
#undef DISPATCH_STRIDE
#undef DISPATCH_MODE
#undef DISPATCH_FUNC
}
} }