#include "./opr_impl_helper.h"
#include "src/common/utils.h"
using namespace megdnn;
/*!
 * \brief Assemble an ElemwiseOpParamN<arity> from the given source tensors.
 *
 * The source tensors are copied into the result, the supplied callback is
 * invoked with pointers to the copied layouts together with the destination
 * layout, and finally init_from_given_tensor() is called on the result.
 *
 * \param opr opaque pointer forwarded unchanged to the callback
 * \param check_layout_and_broadcast callback receiving \p opr, pointers to
 *      the per-operand layouts stored in the result, and dst.layout
 * \param src input tensors; asserted to contain exactly \p arity entries
 * \param dst output tensor; only its layout is used here
 * \return the initialized parameter pack
 */
template <int arity>
ElemwiseOpParamN<arity> ElemwiseLayoutHelper::make_elemwise_op_param(
        void* opr,
        void (*check_layout_and_broadcast)(
                void*, const TensorLayoutPtrArray&, const TensorLayout&),
        const TensorNDArray& src, const TensorND& dst) {
    megdnn_assert(src.size() == static_cast<size_t>(arity));

    ElemwiseOpParamN<arity> result;
    for (int idx = 0; idx < arity; ++idx) {
        result.param[idx] = src[idx];
    }

    // The callback is handed pointers into the copies held by `result`, not
    // into `src`, so whatever it does to the layouts is confined to the
    // returned object (presumably it broadcasts them to dst.layout -- confirm
    // against the callback implementations).
    TensorLayoutPtrArray layout_ptrs(arity);
    for (int idx = 0; idx < arity; ++idx) {
        layout_ptrs[idx] = &result.param[idx].layout;
    }

    check_layout_and_broadcast(opr, layout_ptrs, dst.layout);
    result.init_from_given_tensor();
    return result;
}
// Explicit instantiations of make_elemwise_op_param for arities 1 through 6.
// The helper macro is local to this list and is undefined right after it.
#define INSTANTIATE_MAKE_ELEMWISE_OP_PARAM(N_)                                 \
    template ElemwiseOpParamN<N_>                                              \
    ElemwiseLayoutHelper::make_elemwise_op_param<N_>(                          \
            void*,                                                             \
            void (*)(void*, const TensorLayoutPtrArray&, const TensorLayout&), \
            const TensorNDArray&, const TensorND&)
INSTANTIATE_MAKE_ELEMWISE_OP_PARAM(1);
INSTANTIATE_MAKE_ELEMWISE_OP_PARAM(2);
INSTANTIATE_MAKE_ELEMWISE_OP_PARAM(3);
INSTANTIATE_MAKE_ELEMWISE_OP_PARAM(4);
INSTANTIATE_MAKE_ELEMWISE_OP_PARAM(5);
INSTANTIATE_MAKE_ELEMWISE_OP_PARAM(6);
#undef INSTANTIATE_MAKE_ELEMWISE_OP_PARAM
/*!
 * \brief Build the 3-operand parameter pack and normalize operand order.
 *
 * \param[out] param receives make_elemwise_op_param<3>(), possibly with
 *      operands 0 and 1 exchanged (see below)
 * \param[out] c_is_scalar whether the layout of the third source tensor
 *      (m_src->at(2)) is a broadcasted scalar; computed from the source
 *      layout before \p param is built
 *
 * Operand reordering:
 *  - third operand not scalar: if its layout differs from operand 0, it is
 *    asserted to equal operand 1's layout and operands 0/1 are swapped, so
 *    that afterwards operand 0 shares the third operand's layout;
 *  - third operand scalar: operands 0/1 are swapped when the third operand's
 *    layout equals operand 0's layout.
 */
void ElemwiseForwardImplHelper::prepare_fma3(
        ElemwiseOpParamN<3>& param, bool& c_is_scalar) {
    c_is_scalar = is_broadcasted_scalar(m_src->at(2).layout);
    param = make_elemwise_op_param<3>();

    // Evaluated once, before any swap, against the original operand 0.
    const bool c_matches_first = param[2].layout.eq_layout(param[0].layout);

    if (c_is_scalar) {
        if (c_matches_first) {
            std::swap(param[0], param[1]);
        }
    } else if (!c_matches_first) {
        megdnn_assert_eq_layout(param[2].layout, param[1].layout);
        std::swap(param[0], param[1]);
    }
}
/*!
 * \brief Build the 4-operand parameter pack and normalize operand order.
 *
 * After make_elemwise_op_param<4>(), operands 0 and 1 are swapped when
 * operand 0's layout does not equal operand 2's. On exit it is asserted that
 * operand 0 matches operand 2 and operand 1 matches operand 3 in layout.
 *
 * \param[out] param receives the (possibly reordered) parameter pack
 */
void ElemwiseForwardImplHelper::prepare_fma4(ElemwiseOpParamN<4>& param) {
    param = make_elemwise_op_param<4>();

    const bool first_pairs_with_third =
            param[0].layout.eq_layout(param[2].layout);
    if (!first_pairs_with_third) {
        std::swap(param[0], param[1]);
    }

    megdnn_assert_eq_layout(param[0].layout, param[2].layout);
    megdnn_assert_eq_layout(param[1].layout, param[3].layout);
}
/*!
 * \brief Check whether \p layout addresses a single element on every axis.
 *
 * \return false for any non-DEFAULT tensor format; otherwise true iff each
 *      axis either has extent 1 or stride 0 (vacuously true when ndim is 0)
 */
bool ElemwiseLayoutHelper::is_broadcasted_scalar(const TensorLayout& layout) {
    const bool default_format =
            layout.format.type() == TensorFormat::Type::DEFAULT;
    if (!default_format) {
        return false;
    }

    for (size_t axis = 0; axis < layout.ndim; ++axis) {
        const bool unit_extent = layout.shape[axis] == 1;
        const bool zero_stride = layout.stride[axis] == 0;
        // An axis that has more than one element and actually advances in
        // memory rules out a scalar broadcast.
        if (!(unit_extent || zero_stride)) {
            return false;
        }
    }
    return true;
}
/*!
 * \brief Match a DEFAULT-format layout whose strides are
 *      (slice_size, 0, 1), optionally preceded by one zero-stride axis.
 *
 * Accepted stride patterns:
 *  - ndim == 3: {slice_size, 0, 1}
 *  - ndim == 4: {0, slice_size, 0, 1}
 *
 * \param layout layout to inspect
 * \param[out] info on a match, x/y/z receive the extents of the three axes
 *      carrying the (slice_size, 0, 1) strides; left untouched otherwise
 * \return whether the layout matched
 */
template <size_t slice_size>
bool ElemwiseLayoutHelper::is_broadcastedx_channel_like(
        const TensorLayout& layout, BroadcastChannelInfo& info) {
    if (layout.format.type() != TensorFormat::Type::DEFAULT) {
        return false;
    }

    // Index of the axis expected to carry stride == slice_size.
    size_t lead;
    if (layout.ndim == 3) {
        lead = 0;
    } else if (layout.ndim == 4 && layout.stride[0] == 0) {
        lead = 1;
    } else {
        return false;
    }

    if (layout.stride[lead] != slice_size || layout.stride[lead + 1] != 0 ||
        layout.stride[lead + 2] != 1) {
        return false;
    }

    info.x = layout.shape[lead];
    info.y = layout.shape[lead + 1];
    info.z = layout.shape[lead + 2];
    return true;
}
// Explicit instantiations of is_broadcastedx_channel_like for the two slice
// sizes provided by this translation unit: 4 and 8.
template bool ElemwiseLayoutHelper::is_broadcastedx_channel_like<4>(
        const TensorLayout& layout, BroadcastChannelInfo& info);
template bool ElemwiseLayoutHelper::is_broadcastedx_channel_like<8>(
        const TensorLayout& layout, BroadcastChannelInfo& info);
bool ElemwiseLayoutHelper::is_broadcasted_channel_like(
const TensorLayout& layout, BroadcastChannelInfo& info) {
if (layout.format.type() == TensorFormat::Type::DEFAULT) {
if (layout.ndim == 3 && layout.stride[0] == 0 && layout.stride[2] == 0 &&
layout.stride[1] == 1) {
info.x = layout.shape[0];
info.y = layout.shape[1];
info.z = layout.shape[2];
return true;
} else if (layout.ndim == 2 && layout.stride[1] == 0 && layout.stride[0] == 1) {
info.x = 1;
info.y = layout.shape[0];
info.z = layout.shape[1];
return true;
}
} else {
if (Image2DPack4TensorFormat::is_valid_image(layout)) {
auto align_axis =
layout.format.as_impl<Image2DPack4TensorFormat>().align_axis();
if (layout.ndim == 4 && align_axis == 1 &&
(layout.stride[0] == 0 || layout.shape[0] == 1) &&
layout.stride[1] == 4 && layout.stride[2] == 0 &&
layout.stride[3] == 1) {
info.x = 1;
info.y = 1;
info.z = layout.shape[2];
return true;
} else if (
layout.ndim == 3 && align_axis == 1 &&
(layout.stride[0] == 0 || layout.shape[0] == 1) &&
layout.stride[1] == 0 && layout.shape[2] == 4 &&
layout.stride[2] == 1) {
info.x = 1;
info.y = 1;
info.z = layout.shape[1];
return true;
}
return false;
}
}
return false;
}
/*!
 * \brief Match a DEFAULT-format 3-dim layout broadcast along its middle axis.
 *
 * The layout matches when ndim == 3, stride[0] equals shape[2],
 * stride[1] == 0 and stride[2] == 1.
 *
 * \param layout layout to inspect
 * \param[out] info on a match, x/y/z receive shape[0]/shape[1]/shape[2];
 *      left untouched otherwise
 * \return whether the layout matched
 */
bool ElemwiseLayoutHelper::is_broadcasted_3dim_like(
        const TensorLayout& layout, BroadcastChannelInfo& info) {
    if (layout.format.type() != TensorFormat::Type::DEFAULT) {
        return false;
    }
    if (layout.ndim != 3) {
        return false;
    }

    const bool outer_stride_is_inner_extent =
            (layout.stride[0] - layout.shape[2]) == 0;
    if (!outer_stride_is_inner_extent || layout.stride[1] != 0 ||
        layout.stride[2] != 1) {
        return false;
    }

    info.x = layout.shape[0];
    info.y = layout.shape[1];
    info.z = layout.shape[2];
    return true;
}
/*!
 * \brief Match a DEFAULT-format 2-dim layout with strides {0, 1}.
 *
 * \param layout layout to inspect
 * \param[out] info on a match set to (1, shape[0], shape[1]); left untouched
 *      otherwise
 * \return whether the layout matched
 */
bool ElemwiseLayoutHelper::is_NHWC_broadcasted_channel_like(
        const TensorLayout& layout, BroadcastChannelInfo& info) {
    if (layout.format.type() != TensorFormat::Type::DEFAULT) {
        return false;
    }

    const bool matches = layout.ndim == 2 && layout.stride[1] == 1 &&
                         layout.stride[0] == 0;
    if (!matches) {
        return false;
    }

    info.x = 1;
    info.y = layout.shape[0];
    info.z = layout.shape[1];
    return true;
}
/*!
 * \brief Match a layout that is a contiguous row, optionally repeated along
 *      a leading zero-stride axis.
 *
 * Accepted patterns:
 *  - ndim == 2, strides {0, 1}: binfo = (layout[0], layout[1])
 *  - ndim == 1, stride  {1}:    binfo = (1, layout[0])
 *
 * Note that the tensor format is not inspected here, unlike in
 * is_broadcasted_scalar() and the *_channel_like() checks.
 *
 * \param layout layout to inspect
 * \param[out] binfo filled as listed above on a match; untouched otherwise
 * \return whether the layout matched
 */
bool ElemwiseLayoutHelper::is_broadcasted_1x(
        const TensorLayout& layout, Broadcast1xInfo& binfo) {
    switch (layout.ndim) {
        case 2:
            if (layout.stride[0] == 0 && layout.stride[1] == 1) {
                binfo.x = layout[0];
                binfo.y = layout[1];
                return true;
            }
            break;
        case 1:
            if (layout.stride[0] == 1) {
                binfo.x = 1;
                binfo.y = layout[0];
                return true;
            }
            break;
        default:
            break;
    }
    return false;
}