#include "./opr_impl.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "midout.h"
// midout region tags used by the dispatch switches in ElemwiseImpl::exec.
// exec builds the tag name with CONCAT(megdnn_fallback_elemwise_exec, ARITY,
// CAT), so every (ARITY, CAT) pair dispatched there needs a matching
// declaration here, with exactly this spelling.
MIDOUT_DECL(megdnn_fallback_elemwise_exec_UNARY_INT)
MIDOUT_DECL(megdnn_fallback_elemwise_exec_UNARY_FLOAT)
MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT)
MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT)
namespace megdnn {
namespace fallback {

/*!
 * \brief Execute an elementwise operation on the fallback backend.
 *
 * Dispatch order:
 *  - a non-contiguous \p dst is delegated to naive::ElemwiseForwardImpl;
 *  - for INT and FLOAT destination dtype categories with one (UNARY) or two
 *    (BINARY) inputs, the call is routed to the exec_<ARITY>_<CAT><mode>()
 *    template instantiation selected by m_param.mode;
 *  - every other case (any other dtype category, any other number of
 *    inputs, or a mode case whose MIDOUT region body did not run) ends in
 *    naive::ElemwiseForwardImpl::exec.
 *
 * \param srcs input tensors; their count selects the arity.
 * \param dst  output tensor; its dtype category selects the kernel family.
 *
 * Calls megdnn_throw("bad mode") when m_param.mode is not listed in the
 * MEGDNN_FOREACH_ELEMWISE_MODE_<ARITY>_<CAT> table of the selected family.
 *
 * NOTE(review): m_src / m_dst still point at the caller's arguments after
 * this returns; presumably they are only read by the exec_* helpers during
 * this call -- confirm before relying on them elsewhere.
 */
void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
    if (!dst.layout.is_contiguous()) {
        return naive::ElemwiseForwardImpl::exec(srcs, dst);
    }

    // The exec_* helpers are called with no arguments, so hand them the
    // operands through these members.
    m_src = &srcs;
    m_dst = &dst;

// Token-paste helper: CONCAT(a, b, c) -> a_b_c, expanding a/b/c first.
#define CONCAT2(a, b, c) a##_##b##_##c
#define CONCAT(a, b, c) CONCAT2(a, b, c)

// One switch case per mode. The trailing `break` keeps a case whose MIDOUT
// region body did not execute from falling through into the next mode's
// case label (or into `default`); control then leaves the switch and
// reaches the naive fallback at the end of this function.
#define SWITCH_MODE_CB(_mode)                                                      \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                CONCAT(megdnn_fallback_elemwise_exec, ARITY, CAT),                 \
                midout_iv(Mode::_mode)) {                                          \
            return CONCAT(exec, ARITY, CAT)<param_enumv::Elemwise::Mode::_mode>(); \
        }                                                                          \
        MIDOUT_END();                                                              \
        break;

// Switch over m_param.mode using the mode table for the current ARITY/CAT.
#define SWITCH_MODE                                      \
    switch (m_param.mode) {                              \
        CONCAT(MEGDNN_FOREACH_ELEMWISE_MODE, ARITY, CAT) \
        (SWITCH_MODE_CB)                                 \
        default:                                         \
            megdnn_throw("bad mode");                    \
    }

    if (dst.layout.dtype.category() == DTypeCategory::INT) {
#define CAT INT
        if (srcs.size() == 1) {
#define ARITY UNARY
            SWITCH_MODE
#undef ARITY
        }
        if (srcs.size() == 2) {
#define ARITY BINARY
            SWITCH_MODE
#undef ARITY
        }
#undef CAT
    } else if (dst.layout.dtype.category() == DTypeCategory::FLOAT) {
#define CAT FLOAT
        if (srcs.size() == 1) {
#define ARITY UNARY
            SWITCH_MODE
#undef ARITY
        }
        if (srcs.size() == 2) {
#define ARITY BINARY
            SWITCH_MODE
#undef ARITY
        }
#undef CAT
    }

// Drop the dispatch macros so they do not leak into the rest of the
// translation unit.
#undef SWITCH_MODE
#undef SWITCH_MODE_CB
#undef CONCAT
#undef CONCAT2

    // Any dtype category other than INT/FLOAT, any other number of inputs,
    // or a mode case that broke out of its switch: use the generic naive
    // implementation.
    naive::ElemwiseForwardImpl::exec(srcs, dst);
}

}  // namespace fallback
}  // namespace megdnn