#pragma once
#include "megdnn/oprs/general.h"
#include "megdnn/tensor_format.h"
#include "src/common/elemwise_helper.cuh"
#include "src/common/utils.h"
namespace megdnn {
/*!
 * \brief Collection of static helpers for inspecting elemwise operand layouts.
 *
 * Only is_vector() and the two operator== overloads are defined inline; all
 * other members are declarations whose definitions live outside this header.
 */
class ElemwiseLayoutHelper {
public:
    /*!
     * \brief Triple of sizes produced by the *_channel_like / 3dim_like
     *      predicates declared below.
     *
     * NOTE(review): presumably describes a broadcast of shape (1, y, 1) onto
     * (x, y, z) -- confirm against the out-of-line definitions.
     */
    struct BroadcastChannelInfo {
        size_t x, y, z;

        //! member-wise equality on x, y and z
        bool operator==(const BroadcastChannelInfo& rhs) const {
            if (x != rhs.x) {
                return false;
            }
            if (y != rhs.y) {
                return false;
            }
            return z == rhs.z;
        }
    };

    /*!
     * \brief Pair of sizes produced by is_broadcasted_1x().
     *
     * NOTE(review): presumably describes a broadcast of shape (1, y) onto
     * (x, y) -- confirm against the out-of-line definition.
     */
    struct Broadcast1xInfo {
        size_t x, y;

        //! member-wise equality on x and y
        bool operator==(const Broadcast1xInfo& rhs) const {
            if (x != rhs.x) {
                return false;
            }
            return y == rhs.y;
        }
    };

    /*!
     * \brief Build an ElemwiseOpParamN<arity> from the given operands.
     *
     * \param opr opaque pointer handed back as the first argument of
     *      \p check_layout_and_broadcast
     * \param check_layout_and_broadcast callback receiving \p opr, an array of
     *      source layout pointers and a destination layout
     * \param src source tensors
     * \param dst destination tensor
     *
     * Declared only; the definition is not part of this header.
     */
    template <int arity>
    static ElemwiseOpParamN<arity> make_elemwise_op_param(
            void* opr,
            void (*check_layout_and_broadcast)(
                    void*, const TensorLayoutPtrArray&, const TensorLayout&),
            const TensorNDArray& src, const TensorND& dst);

    /*!
     * \brief Whether \p layout can be treated as a plain vector.
     *
     * For the DEFAULT tensor format this requires exactly one dimension with
     * unit stride; for every other format the answer is delegated to
     * TensorLayout::is_contiguous().
     */
    static bool is_vector(const TensorLayout& layout) {
        const bool default_format =
                layout.format.type() == TensorFormat::Type::DEFAULT;
        return default_format ? (layout.ndim == 1 && layout.stride[0] == 1)
                              : layout.is_contiguous();
    }

    // The predicates below are declared here and defined elsewhere.
    // NOTE(review): each presumably fills its output argument only when it
    // returns true -- confirm against the definitions.

    //! test for the "1x" broadcast pattern; result details go to \p binfo
    static bool is_broadcasted_1x(const TensorLayout& layout, Broadcast1xInfo& binfo);

    //! test whether \p layout is a broadcasted scalar
    static bool is_broadcasted_scalar(const TensorLayout& layout);

    //! test for a channel-like broadcast; result details go to \p info
    static bool is_broadcasted_channel_like(
            const TensorLayout& layout, BroadcastChannelInfo& info);

    //! test for a 3-dim-like broadcast; result details go to \p info
    static bool is_broadcasted_3dim_like(
            const TensorLayout& layout, BroadcastChannelInfo& info);

    //! NHWC variant of the channel-like broadcast test
    static bool is_NHWC_broadcasted_channel_like(
            const TensorLayout& layout, BroadcastChannelInfo& info);

    //! channel-like broadcast test parameterised by a compile-time slice size
    template <size_t slice_size>
    static bool is_broadcastedx_channel_like(
            const TensorLayout& layout, BroadcastChannelInfo& info);
};
/*!
 * \brief Base for ElemwiseForward implementations that bundles the layout
 *      helpers with per-call source/destination pointers.
 *
 * ElemwiseLayoutHelper is inherited as protected, so its static helpers are
 * usable from this class and its subclasses without qualification.
 */
class ElemwiseForwardImplHelper : public ElemwiseForward,
                                  protected ElemwiseLayoutHelper {
    /*!
     * \brief Trampoline matching the callback signature expected by
     *      ElemwiseLayoutHelper::make_elemwise_op_param.
     *
     * \param opr must point to an ElemwiseForwardImplHelper; it is cast back
     *      and its check_layout_and_broadcast() is invoked with \p src / \p dst
     */
    static void call_check_layout_and_broadcast(
            void* opr, const TensorLayoutPtrArray& src, const TensorLayout& dst) {
        auto* self = static_cast<ElemwiseForwardImplHelper*>(opr);
        self->check_layout_and_broadcast(src, dst);
    }

protected:
    //! source operands for the current call; null until assigned
    const TensorNDArray* m_src = nullptr;
    //! destination operand for the current call; null until assigned
    const TensorND* m_dst = nullptr;

    /*!
     * \brief Build the op param for the tensors referenced by m_src / m_dst.
     *
     * Both pointers are dereferenced unconditionally, so they must have been
     * set to valid objects before this is called.
     */
    template <int arity>
    ElemwiseOpParamN<arity> make_elemwise_op_param() {
        const TensorNDArray& inputs = *m_src;
        const TensorND& output = *m_dst;
        return ElemwiseLayoutHelper::make_elemwise_op_param<arity>(
                this, call_check_layout_and_broadcast, inputs, output);
    }

    // Declared here, defined elsewhere.
    // NOTE(review): presumably these prepare operands for the fused
    // multiply-add modes with 3 and 4 inputs -- confirm against definitions.
    void prepare_fma3(ElemwiseOpParamN<3>& param, bool& c_is_scalar);
    void prepare_fma4(ElemwiseOpParamN<4>& param);

public:
    //! inherit the constructors of ElemwiseForward
    using ElemwiseForward::ElemwiseForward;
};
}