#include "src/arm_common/elemwise/ternary/algo.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_elemwise_ternary)
using namespace megdnn;
using namespace arm_common;
// is_available() helpers. DISPATCH_TYPE (defined outside this file) expands
// DISPATCH_MODE_FLOAT / DISPATCH_MODE_INT; as defined here they make the
// enclosing is_available() return true when kern_param.mode is FUSE_MUL_ADD3.
// For any other mode nothing is returned here, and control presumably falls
// through to the `return false` at the end of DECL_AVAILABLE's body.
// Integer dtypes reuse exactly the same check as floating-point ones.
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
auto mode = kern_param.mode; \
if (mode == Mode::FUSE_MUL_ADD3) \
return true;
#define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT
// Defines ElemwiseImpl::AlgoTernaryFma3<case>::is_available(). The algorithm
// can only be available when kern_param.broad_cast_type equals `type`. In
// that case DISPATCH_TYPE is expanded with `kern_param` and `src0` in scope
// (presumably to pick a dtype case) and may return true via the
// DISPATCH_MODE_* macros above. Every other path returns false.
// `#case##_hash` stringizes `case` and pastes the `_hash` suffix onto that
// literal, which is concatenated with the preceding string literal to form
// the key handed to DISPATCH_TYPE.
#define DECL_AVAILABLE(case, type) \
bool ElemwiseImpl::AlgoTernaryFma3##case ::is_available( \
const KernParam& kern_param) const { \
if (type == kern_param.broad_cast_type) { \
auto& elparam = kern_param.ternary_elparam; \
auto& src0 = elparam[0]; \
DISPATCH_TYPE("AlgoTernaryFma3::is_available" #case##_hash); \
} \
return false; \
}
// One is_available() per supported broadcast pattern; each class-name suffix
// mirrors the BcastType enumerator it accepts.
DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC);
DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR);
DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101);
DECL_AVAILABLE(Bcast111CVecBcast111C, BcastType::BCAST111C_VEC_BCAST111C);
DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX);
DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC);
DECL_AVAILABLE(VecBcast111CVec, BcastType::VEC_BCAST111C_VEC);
DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC);
DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC);
DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR);
// Retire the is_available() helper macros. DECL_AVAILABLE must be undefined
// here so it does not leak into the rest of the translation unit.
// DISPATCH_MODE_FLOAT / DISPATCH_MODE_INT are undefined so that the
// exec()-side variants below can be defined with different bodies.
#undef DECL_AVAILABLE
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT
// exec() helpers. Before invoking DISPATCH_TYPE, each exec() below defines its
// own DISPATCH_TERNARY, which emits one `case` label. These macros wrap that
// case in a switch on kern_param.mode. FUSE_MUL_ADD3 (kernel FuseMulAdd3Op) is
// the only mode handled; any other mode throws via megdnn_throw. Integer dtypes
// share the floating-point dispatch unchanged.
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
switch (kern_param.mode) { \
DISPATCH_TERNARY(FUSE_MUL_ADD3, _case, _type, _type_midout_id, FuseMulAdd3Op); \
default: \
megdnn_throw(ssprintf( \
"No avaiable algo find for: %d", \
static_cast<int>(kern_param.mode))); \
}
#define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT
// Case: src0, src1 and src2 are all full vectors (BcastType::VEC_VEC_VEC).
// The tensors are handed to OpCallerTernary as one flat run of
// src0.layout.total_nr_elems() elements, in a single call passed to
// MEGDNN_DISPATCH_CPU_KERN.
void ElemwiseImpl::AlgoTernaryFma3VecVecVec::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
    auto&& dst = *(kern_param.m_dst);
    const size_t total_elems = src0.layout.total_nr_elems();

    // One `case` of the mode switch built by DISPATCH_MODE_FLOAT/INT. The
    // trailing `return` (completed by the `;` in that macro) leaves exec()
    // once the kernel has been dispatched.
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t)>                                          \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>, BcastType::VEC_VEC_VEC>::run;        \
            MEGDNN_DISPATCH_CPU_KERN(                                               \
                    static_cast<naive::HandleImpl*>(kern_param.handle),             \
                    run(static_cast<const _type*>(src0.raw_ptr()),                  \
                        static_cast<const _type*>(src1.raw_ptr()),                  \
                        static_cast<const _type*>(src2.raw_ptr()),                  \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,      \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,     \
                        total_elems));                                              \
        }                                                                           \
        MIDOUT_END();                                                               \
        return

    DISPATCH_TYPE("AlgoTernaryFma3VecVecVec::exec"_hash);
#undef DISPATCH_TERNARY
    return;
}
// Case: src0 and src1 are full vectors and src2 is a scalar
// (BcastType::VEC_VEC_SCALAR). The first element of src2 is read and passed
// by value. src0.layout.total_nr_elems() elements are processed in a single
// call handed to MEGDNN_DISPATCH_CPU_KERN.
void ElemwiseImpl::AlgoTernaryFma3VecVecScalar::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
    auto&& dst = *(kern_param.m_dst);
    const size_t total_elems = src0.layout.total_nr_elems();

    // One `case` of the mode switch built by DISPATCH_MODE_FLOAT/INT; it
    // returns from exec() after dispatching.
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),              \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type*, const _type, _type*, DType, DType, \
                    DType, DType, size_t)>                                         \
                    run = OpCallerTernary<                                         \
                            _op<_type, _type>, BcastType::VEC_VEC_SCALAR>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr()),                 \
                        static_cast<const _type*>(src2.raw_ptr())[0],              \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,    \
                        total_elems));                                             \
        }                                                                          \
        MIDOUT_END();                                                              \
        return

    DISPATCH_TYPE("AlgoTernaryFma3VecVecScalar::exec"_hash);
#undef DISPATCH_TERNARY
    return;
}
// Case: src0 and src2 are broadcast over channels ({1, C, 1, 1},
// BcastType::BCAST101_VEC_BCAST101) while src1 and dst are full tensors.
// binfo is derived from src0's layout. As used below, binfo.x is passed to
// the kernel as the outer count, binfo.y is the number of channels and
// binfo.z the number of elements per channel. The channels are split into
// nr_threads tasks (the dispatcher's thread count).
void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
    BroadcastChannelInfo binfo;
    is_broadcasted_channel_like(src0.layout, binfo);

    size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle)
                                ->megcore_dispatcher()
                                ->nr_threads();
    size_t nr_channels = binfo.y;
    // Ceiling division so that every channel is owned by some task.
    size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads;
    auto&& dst = *(kern_param.m_dst);

    // Task `task_id` owns channels [offset, offset + nr_channels_thread).
    // Because of the ceiling division, trailing tasks can start at or beyond
    // nr_channels (e.g. 2 channels on 4 threads, or 5 channels on 4 threads).
    // Such tasks must do nothing. Without the early return,
    // `nr_channels - offset` would wrap around as size_t, std::min would
    // select nr_channels_per_thread, and the task would read and write
    // channels past the end of the tensors.
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t, size_t, size_t, size_t)>                  \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>,                                      \
                            BcastType::BCAST101_VEC_BCAST101>::run;                 \
            auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2,   \
                           binfo, dst, run](size_t task_id, size_t) {               \
                size_t offset = task_id * nr_channels_per_thread;                   \
                if (offset >= nr_channels)                                          \
                    return;                                                         \
                size_t nr_channels_thread =                                         \
                        std::min(nr_channels - offset, nr_channels_per_thread);     \
                run(static_cast<const _type*>(src0.raw_ptr()) + offset,             \
                    static_cast<const _type*>(src1.raw_ptr()) + offset * binfo.z,   \
                    static_cast<const _type*>(src2.raw_ptr()) + offset,             \
                    static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z,          \
                    src0.layout.dtype, src1.layout.dtype, src2.layout.dtype,        \
                    dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z,         \
                    binfo.y * binfo.z);                                             \
            };                                                                      \
            MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                  \
                    static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \
                    kernel);                                                        \
        }                                                                           \
        MIDOUT_END();                                                               \
        return

    DISPATCH_TYPE("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash);
#undef DISPATCH_TERNARY
    return;
}
// Case: src0 and src2 are broadcast along the trailing channel axis (the
// "111C" pattern, BcastType::BCAST111C_VEC_BCAST111C) while src1 and dst are
// full tensors. binfo is derived from src0's layout. As used below, binfo.y
// counts the rows that are split across tasks and binfo.z is the length of
// each row. src1 may be row-padded: unless is_vector() holds for its layout,
// each row is followed by src1.layout.stride[0] - binfo.z extra elements.
void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
    BroadcastChannelInfo binfo;
    is_NHWC_broadcasted_channel_like(src0.layout, binfo);

    size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle)
                                ->megcore_dispatcher()
                                ->nr_threads();
    size_t nr_channels = binfo.y;
    // Ceiling division so that every row is owned by some task.
    size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads;
    auto&& dst = *(kern_param.m_dst);

    // Task `task_id` owns rows [offset, offset + nr_channels_thread).
    // Because of the ceiling division, trailing tasks can start at or beyond
    // nr_channels (e.g. 2 rows on 4 threads, or 5 rows on 4 threads). Such
    // tasks must do nothing. Without the early return, `nr_channels - offset`
    // would wrap around as size_t and the task would read and write rows past
    // the end of src1 and dst.
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                   \
    case Mode::_mode:                                                                 \
        MIDOUT_BEGIN(                                                                 \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),                 \
                midout_iv(Mode::_mode), _type_midout_id) {                            \
            thin_function<void(                                                       \
                    const _type*, const _type*, size_t, const _type*, _type*, DType,  \
                    DType, DType, DType, size_t, size_t, size_t, size_t)>             \
                    run = OpCallerTernary<                                            \
                            _op<_type, _type>,                                        \
                            BcastType::BCAST111C_VEC_BCAST111C>::run;                 \
            auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2,     \
                           binfo, dst, run](size_t task_id, size_t) {                 \
                size_t offset = task_id * nr_channels_per_thread;                     \
                if (offset >= nr_channels)                                            \
                    return;                                                           \
                size_t nr_channels_thread =                                           \
                        std::min(nr_channels - offset, nr_channels_per_thread);       \
                size_t src1_offset =                                                  \
                        is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z; \
                run(static_cast<const _type*>(src0.raw_ptr()),                        \
                    static_cast<const _type*>(src1.raw_ptr()) +                       \
                            offset * (binfo.z + src1_offset),                         \
                    src1_offset, static_cast<const _type*>(src2.raw_ptr()),           \
                    static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z,            \
                    src0.layout.dtype, src1.layout.dtype, src2.layout.dtype,          \
                    dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z,           \
                    binfo.y * binfo.z);                                               \
            };                                                                        \
            MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                    \
                    static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads,   \
                    kernel);                                                          \
        }                                                                             \
        MIDOUT_END();                                                                 \
        return

    DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash);
#undef DISPATCH_TERNARY
    return;
}
// Case: src0 and src2 are channel-broadcast in a blocked layout (per the
// assert, channel blocks of 4 or 8, i.e. nchw44 / nchw88), while src1 and dst
// are full tensors (BcastType::BCAST101xX_VEC_BCAST101xX). binfo is derived
// from src0's layout. The batch count is recovered from src1's leading
// dimension divided by binfo.x * binfo.y * binfo.z.
void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
    auto&& dst = *(kern_param.m_dst);

    BroadcastChannelInfo binfo;
    megdnn_assert(
            is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
                    is_broadcastedx_channel_like<8>(src0.layout, binfo),
            "only nchw44 and nchw88 supported");
    size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);

    // One `case` of the mode switch built by DISPATCH_MODE_FLOAT/INT; it
    // returns from exec() after dispatching a single kernel call.
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t, size_t, size_t, size_t)>                  \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>,                                      \
                            BcastType::BCAST101xX_VEC_BCAST101xX>::run;             \
            MEGDNN_DISPATCH_CPU_KERN(                                               \
                    static_cast<naive::HandleImpl*>(kern_param.handle),             \
                    run(static_cast<const _type*>(src0.raw_ptr()),                  \
                        static_cast<const _type*>(src1.raw_ptr()),                  \
                        static_cast<const _type*>(src2.raw_ptr()),                  \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,      \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,     \
                        batch_size, binfo.x, binfo.y, binfo.z));                    \
        }                                                                           \
        MIDOUT_END();                                                               \
        return

    DISPATCH_TYPE("AlgoTernaryFma3Bcast101xXVecBcast101xX::exec"_hash);
#undef DISPATCH_TERNARY
    return;
}
// Case: src1 is channel-broadcast in a blocked layout (per the assert,
// channel blocks of 4 or 8, i.e. nchw44 / nchw88), while src0, src2 and dst
// are full tensors (BcastType::VEC_BCAST101xX_VEC). binfo is derived from
// src1's layout. The batch count is recovered from src0's leading dimension
// divided by binfo.x * binfo.y * binfo.z.
void ElemwiseImpl::AlgoTernaryFma3VecBcast101xXVec::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
    auto&& dst = *(kern_param.m_dst);

    BroadcastChannelInfo binfo;
    megdnn_assert(
            is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
                    is_broadcastedx_channel_like<8>(src1.layout, binfo),
            "only nchw44 and nchw88 supported");
    size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);

    // One `case` of the mode switch built by DISPATCH_MODE_FLOAT/INT; it
    // returns from exec() after dispatching a single kernel call.
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t, size_t, size_t, size_t)>                  \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>, BcastType::VEC_BCAST101xX_VEC>::run; \
            MEGDNN_DISPATCH_CPU_KERN(                                               \
                    static_cast<naive::HandleImpl*>(kern_param.handle),             \
                    run(static_cast<const _type*>(src0.raw_ptr()),                  \
                        static_cast<const _type*>(src1.raw_ptr()),                  \
                        static_cast<const _type*>(src2.raw_ptr()),                  \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,      \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,     \
                        batch_size, binfo.x, binfo.y, binfo.z));                    \
        }                                                                           \
        MIDOUT_END();                                                               \
        return

    DISPATCH_TYPE("AlgoTernaryFma3VecBcast101xXVec::exec"_hash);
#undef DISPATCH_TERNARY
    return;
}
// Case: src1 is broadcast over channels ({1, C, 1, 1},
// BcastType::VEC_BCAST101_VEC) while src0, src2 and dst are full tensors.
// binfo is derived from src1's layout, and its x / y / z extents are passed
// straight to the kernel in a single dispatched call.
void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
    auto&& dst = *(kern_param.m_dst);

    BroadcastChannelInfo binfo;
    is_broadcasted_channel_like(src1.layout, binfo);

    // One `case` of the mode switch built by DISPATCH_MODE_FLOAT/INT; it
    // returns from exec() after dispatching.
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t, size_t, size_t)>                          \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>, BcastType::VEC_BCAST101_VEC>::run;   \
            MEGDNN_DISPATCH_CPU_KERN(                                               \
                    static_cast<naive::HandleImpl*>(kern_param.handle),             \
                    run(static_cast<const _type*>(src0.raw_ptr()),                  \
                        static_cast<const _type*>(src1.raw_ptr()),                  \
                        static_cast<const _type*>(src2.raw_ptr()),                  \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,      \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,     \
                        binfo.x, binfo.y, binfo.z));                                \
        }                                                                           \
        MIDOUT_END();                                                               \
        return

    DISPATCH_TYPE("AlgoTernaryFma3VecBcast101Vec::exec"_hash);
#undef DISPATCH_TERNARY
    return;
}
void ElemwiseImpl::AlgoTernaryFma3VecBcast111CVec::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
BroadcastChannelInfo binfo;
is_NHWC_broadcasted_channel_like(src1.layout, binfo);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, size_t, const _type*, const _type*, size_t, _type*, \
DType, DType, DType, DType, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, BcastType::VEC_BCAST111C_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
is_vector(src0.layout) ? 0 : src0.layout.stride[0] - binfo.z, \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<const _type*>(src2.raw_ptr()), \
is_vector(src2.layout) ? 0 : src2.layout.stride[0] - binfo.z, \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
binfo.x, binfo.y, binfo.z)); \
} \
MIDOUT_END(); \
return
auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE("AlgoTernaryFma3VecBcast111CVec::exec"_hash);
#undef DISPATCH_TERNARY
return;
}
// Case: src0 and src2 are full vectors while src1 is a scalar
// (BcastType::VEC_SCALAR_VEC). The first element of src1 is read and passed
// by value. src0.layout.total_nr_elems() elements are processed in a single
// call handed to MEGDNN_DISPATCH_CPU_KERN.
void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
    auto&& dst = *(kern_param.m_dst);
    const size_t total_elems = src0.layout.total_nr_elems();

    // One `case` of the mode switch built by DISPATCH_MODE_FLOAT/INT; it
    // returns from exec() after dispatching.
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),              \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type, const _type*, _type*, DType, DType, \
                    DType, DType, size_t)>                                         \
                    run = OpCallerTernary<                                         \
                            _op<_type, _type>, BcastType::VEC_SCALAR_VEC>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr())[0],              \
                        static_cast<const _type*>(src2.raw_ptr()),                 \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,    \
                        total_elems));                                             \
        }                                                                          \
        MIDOUT_END();                                                              \
        return

    DISPATCH_TYPE("AlgoTernaryFma3VecScalarVec::exec"_hash);
#undef DISPATCH_TERNARY
    return;
}
// Case: src0 is a full vector while src1 and src2 are scalars
// (BcastType::VEC_SCALAR_SCALAR). The first element of each of src1 and src2
// is read and passed by value. src0.layout.total_nr_elems() elements are
// processed in a single call handed to MEGDNN_DISPATCH_CPU_KERN.
void ElemwiseImpl::AlgoTernaryFma3VecScalarScalar::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
    auto&& dst = *(kern_param.m_dst);
    const size_t total_elems = src0.layout.total_nr_elems();

    // One `case` of the mode switch built by DISPATCH_MODE_FLOAT/INT; it
    // returns from exec() after dispatching.
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)               \
    case Mode::_mode:                                                             \
        MIDOUT_BEGIN(                                                             \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),             \
                midout_iv(Mode::_mode), _type_midout_id) {                        \
            thin_function<void(                                                   \
                    const _type*, const _type, const _type, _type*, DType, DType, \
                    DType, DType, size_t)>                                        \
                    run = OpCallerTernary<                                        \
                            _op<_type, _type>, BcastType::VEC_SCALAR_SCALAR>::run; \
            MEGDNN_DISPATCH_CPU_KERN(                                             \
                    static_cast<naive::HandleImpl*>(kern_param.handle),           \
                    run(static_cast<const _type*>(src0.raw_ptr()),                \
                        static_cast<const _type*>(src1.raw_ptr())[0],             \
                        static_cast<const _type*>(src2.raw_ptr())[0],             \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,    \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,   \
                        total_elems));                                            \
        }                                                                         \
        MIDOUT_END();                                                             \
        return

    DISPATCH_TYPE("AlgoTernaryFma3VecScalarScalar::exec"_hash);
#undef DISPATCH_TERNARY
    return;
}
// Retire the exec()-side mode dispatch helpers now that every exec() above
// has expanded them.
#undef DISPATCH_MODE_INT
#undef DISPATCH_MODE_FLOAT