#pragma once
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
namespace mgb {
namespace opr {
namespace intl {
/*!
 * \brief hook run by the generated `make()` factories before the operator
 *      node is inserted into the graph
 *
 * It receives the operator Param and pointers to the input SymbolVars, so a
 * specialization is able to rewrite the inputs in place. This primary template
 * leaves everything untouched (presumably specialized per operator elsewhere).
 */
template <class Opr>
struct MegDNNOprInitInputsModifier {
    static void apply(
            const typename Opr::Param& /* param */,
            std::initializer_list<SymbolVar*> /* inputs */) {}
};
/*!
 * \brief hook run as the last statement of the constructors generated by the
 *      MEGDNN_OPR_CTOR_INIT* macros
 *
 * The primary template does nothing with the freshly constructed operator.
 */
template <class Opr>
struct MegDNNOprInitPostCtor {
    static void apply(cg::OperatorNodeBase& /* opr */) {}
};
/*!
 * \brief obtain a megdnn::Workspace corresponding to the given var
 *
 * Defined out of line (not visible here). Presumably `var` is an operator's
 * workspace output and the result refers to its device storage -- confirm.
 */
megdnn::Workspace get_megdnn_workspace_from_var(VarNode* var);
/*!
 * \brief static-only helper for per-graph workspace-limit handling
 *
 * All members are static and defined out of line. State is reached through
 * the private `get_impl(graph)` accessor and the opaque `Impl` class.
 */
class WorkspaceLimitGetter {
public:
    //! compute a workspace limit for \p cn given the previous \p old_limit
    static size_t get_workspace_limit(
            ComputingGraph* graph, CompNode cn, size_t old_limit);

    //! query whether \p graph is currently in a pre-allocation run
    static bool is_prealloc_run(ComputingGraph* graph);

    //! register the helper with \p graph and return the associated var
    static VarNode* register_to_graph(ComputingGraph* graph);

private:
    class Impl;
    static Impl* get_impl(ComputingGraph* graph);
};
/*!
 * \brief compile-time flag forwarded to init_output_static_infer_desc_workspace
 *      by the wrapper templates
 *
 * Defaults to false for every MegDNN operator; specializations may set it to
 * true.
 */
template <class MegDNNOpr>
struct AutoAddWorkspaceNeedLimitGetter {
    static constexpr bool val{false};
};
class MegDNNDynOutMallocImpl final : public megdnn::DynOutMallocPolicy {
cg::OperatorNodeBase* m_opr;
CompNode m_cn;
public:
MegDNNDynOutMallocImpl(cg::OperatorNodeBase* opr, CompNode cn)
: m_opr{opr}, m_cn{cn} {}
megdnn::TensorND alloc_output(
size_t id, DType dtype, const TensorShape& shape, void* user_data) override;
void* alloc_workspace(size_t sz, void* user_data) override;
void free_workspace(void* ptr, void* user_data) override;
};
// MegDNNOprMethInvoker<Opr> selects a specialization of
// _MegDNNOprMethInvoker by the operator's NR_INPUTS / NR_OUTPUTS. The wrapper
// code below uses its static members exec(), deduce_layout() and
// get_workspace_in_bytes().
//
// The primary template is only declared, never defined. Instantiating it with
// an (inputs, outputs) pair that has no specialization therefore fails at
// compile time with an incomplete-type error.
//
// Specializations come from the repeated #include of the .inl file below.
// Before each include, _NR_INPUTS, _NR_OUTPUTS and _FOREACH_IO(_i, _o) are
// defined for one arity. _FOREACH_IO applies _i to every input index and _o to
// every output index. Presumably the .inl #undefs these macros at its end,
// since they are redefined before every include -- TODO confirm.
//
// Supported (inputs, outputs) pairs:
//   (1,1) (1,2) (1,3) (2,1) (2,2) (3,1) (3,2) (3,3) (4,1) (4,4)
//   (5,1) (5,2) (5,3) (6,1) (6,2) (6,3) (7,3) (9,4)
//
// NOTE(review): this is an unnamed namespace inside a header, so every
// translation unit gets its own copy. Meanwhile the externally-linked wrapper
// templates below refer to it, which is a potential ODR hazard -- confirm it
// is intentional. Names starting with '_' plus an uppercase letter
// (_MegDNNOprMethInvoker, _NR_INPUTS, _FOREACH_IO, ...) are also reserved
// identifiers in C++.
namespace {
template <int nr_in, int nr_out>
struct _MegDNNOprMethInvoker;
template <class Opr>
using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPUTS>;
// ---- per-arity specializations (order matters: define macros, then include)
#define _NR_INPUTS 1
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 1
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 1
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) _i(0), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 2
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 2
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 3
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 3
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 3
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 4
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 4
#define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0), _o(1), _o(2), _o(3)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 5
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 5
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 5
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 6
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 6
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 6
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) \
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 7
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) \
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 9
#define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o) \
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _i(7), _i(8), _o(0), _o(1), \
_o(2), _o(3)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
}  // anonymous namespace
/*!
 * Set up static shape inference for a forward wrapper.
 *
 * All outputs except the last one are handed to the base class as managed
 * outputs (presumably the last output is the workspace var -- confirm). Then
 * the workspace inference is configured with this operator's
 * AutoAddWorkspaceNeedLimitGetter flag.
 */
template <class MegDNNOpr>
void MegDNNOprWrapperFwd<MegDNNOpr>::init_output_static_infer_desc() {
    const auto nr_managed_outputs = this->output().size() - 1;
    Super::set_nr_managed_outputs(nr_managed_outputs);
    Super::init_output_static_infer_desc();

    const bool need_limit_getter = AutoAddWorkspaceNeedLimitGetter<MegDNNOpr>::val;
    this->init_output_static_infer_desc_workspace(need_limit_getter);
}
//! Execute the wrapped MegDNN operator through the arity-specific invoker.
template <class MegDNNOpr>
void MegDNNOprWrapperFwd<MegDNNOpr>::scn_do_execute() {
    using Invoker = MegDNNOprMethInvoker<MegDNNOpr>;
    Invoker::exec(this->megdnn_opr(), this);
}
/*!
 * Workspace size for the given shapes. The query is delegated to the holder
 * mixin, which asks the MegDNN operator through the invoker.
 */
template <class MegDNNOpr>
size_t MegDNNOprWrapperFwd<MegDNNOpr>::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    const auto& self_node = *this;
    return this->mixin_get_workspace_size_bytes_by_megdnn(
            self_node, input_shapes, output_shapes);
}
//! Fill \p out_shape from \p inp_shape using the MegDNN operator's
//! deduce_layout, reached through the arity-specific invoker.
template <class MegDNNOpr>
void MegDNNOprWrapperFwd<MegDNNOpr>::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    using Invoker = MegDNNOprMethInvoker<MegDNNOpr>;
    Invoker::deduce_layout(this->megdnn_opr(), this, inp_shape, out_shape);
}
/*!
 * Set up static shape inference for a backward wrapper.
 *
 * The mixin's backward static-infer setup runs first. Then the workspace
 * inference is configured with this operator's
 * AutoAddWorkspaceNeedLimitGetter flag.
 */
template <class MegDNNOpr>
void MegDNNOprWrapperBwd<MegDNNOpr>::init_output_static_infer_desc() {
    this->mixin_init_output_static_infer_desc_bwd(*this);

    const bool need_limit_getter = AutoAddWorkspaceNeedLimitGetter<MegDNNOpr>::val;
    this->init_output_static_infer_desc_workspace(need_limit_getter);
}
//! Execute the wrapped MegDNN operator through the arity-specific invoker.
template <class MegDNNOpr>
void MegDNNOprWrapperBwd<MegDNNOpr>::scn_do_execute() {
    using Invoker = MegDNNOprMethInvoker<MegDNNOpr>;
    Invoker::exec(this->megdnn_opr(), this);
}
/*!
 * Workspace size for the given shapes. The query is delegated to the holder
 * mixin, which asks the MegDNN operator through the invoker.
 */
template <class MegDNNOpr>
size_t MegDNNOprWrapperBwd<MegDNNOpr>::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    const auto& self_node = *this;
    return this->mixin_get_workspace_size_bytes_by_megdnn(
            self_node, input_shapes, output_shapes);
}
/*!
 * Build the node properties. The base class's properties are created first,
 * the mixin then adjusts them, and the same object is returned.
 */
template <class MegDNNOpr>
typename MegDNNOprWrapperBwd<MegDNNOpr>::Super::NodeProp* MegDNNOprWrapperBwd<
        MegDNNOpr>::do_make_node_prop() const {
    auto node_prop = Super::do_make_node_prop();
    this->mixin_update_node_prop(*this, node_prop);
    return node_prop;
}
}
namespace mixin {
/*!
 * Ask the held MegDNN operator for its workspace size, given the input and
 * output shapes of \p opr.
 *
 * Only valid for holders instantiated with add_workspace == true; this is
 * enforced at compile time.
 */
template <class MegDNNOpr, bool add_workspace, class OprHolder>
size_t MegDNNOprHolderImpl<MegDNNOpr, add_workspace, OprHolder>::
        mixin_get_workspace_size_bytes_by_megdnn(
                const OperatorNodeBase& opr, const TensorShapeArray& input_shapes,
                const TensorShapeArray& output_shapes) const {
    static_assert(add_workspace, "must add_workspace");
    using Invoker = intl::MegDNNOprMethInvoker<MegDNNOpr>;
    return Invoker::get_workspace_in_bytes(
            this->megdnn_opr(), &opr, input_shapes, output_shapes);
}
}
} }
/*!
 * \brief define the constructor of a one-input MegDNN operator wrapper
 *
 * Expands to `_name::_name(VarNode* i0, const Param&, const OperatorNodeConfig&)`.
 * Extra macro arguments are forwarded to the Super constructor after the
 * OperatorNodeBaseCtorParam. The generated body runs, in order:
 * init_megdnn_opr, add_input, then the MegDNNOprInitPostCtor hook.
 */
#define MEGDNN_OPR_CTOR_INIT1(_name, _node_name, ...) \
    _name::_name( \
            VarNode* i0, const Param& param, const OperatorNodeConfig& config) \
            : Super(OperatorNodeBaseCtorParam{i0->owner_graph(), config, \
                                              _node_name, {i0}}, \
                    ##__VA_ARGS__) { \
        init_megdnn_opr(*this, param); \
        add_input({i0}); \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this); \
    }

/*!
 * \brief define both the constructor (via MEGDNN_OPR_CTOR_INIT1) and the
 *      static `make()` factory of a one-input MegDNN operator wrapper
 */
#define MEGDNN_OPR_INIT1(_name, _node_name, ...) \
    MEGDNN_OPR_CTOR_INIT1(_name, _node_name, ##__VA_ARGS__) \
    SymbolVar _name::make( \
            SymbolVar i0, const Param& param, const OperatorNodeConfig& config) { \
        /* the hook gets a pointer to the input and may rewrite it, */ \
        /* so the node is read only after it has run */ \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0}); \
        return i0.insert_single_output_opr<_name>(i0.node(), param, config); \
    }
/*!
 * \brief define the constructor of a two-input MegDNN operator wrapper
 *
 * Expands to `_name::_name(VarNode* i0, VarNode* i1, const Param&,
 * const OperatorNodeConfig&)`. Only i0 appears in the
 * OperatorNodeBaseCtorParam input list; both inputs are registered through
 * add_input(). Extra macro arguments are forwarded to the Super constructor.
 */
#define MEGDNN_OPR_CTOR_INIT2(_name, _node_name, ...) \
    _name::_name( \
            VarNode* i0, VarNode* i1, const Param& param, \
            const OperatorNodeConfig& config) \
            : Super(OperatorNodeBaseCtorParam{i0->owner_graph(), config, \
                                              _node_name, {i0}}, \
                    ##__VA_ARGS__) { \
        init_megdnn_opr(*this, param); \
        add_input({i0, i1}); \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this); \
    }

/*!
 * \brief define both the constructor (via MEGDNN_OPR_CTOR_INIT2) and the
 *      static `make()` factory of a two-input MegDNN operator wrapper
 */
#define MEGDNN_OPR_INIT2(_name, _node_name, ...) \
    MEGDNN_OPR_CTOR_INIT2(_name, _node_name, ##__VA_ARGS__) \
    SymbolVar _name::make( \
            SymbolVar i0, SymbolVar i1, const Param& param, \
            const OperatorNodeConfig& config) { \
        /* the hook gets pointers to the inputs and may rewrite them, */ \
        /* so the nodes are read only after it has run */ \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0, &i1}); \
        return i0.insert_single_output_opr<_name>( \
                i0.node(), i1.node(), param, config); \
    }
/*!
 * \brief define the constructor of a three-input MegDNN operator wrapper
 *
 * Expands to `_name::_name(VarNode* i0, VarNode* i1, VarNode* i2,
 * const Param&, const OperatorNodeConfig&)`. Only i0 appears in the
 * OperatorNodeBaseCtorParam input list; all inputs are registered through
 * add_input(). Extra macro arguments are forwarded to the Super constructor.
 */
#define MEGDNN_OPR_CTOR_INIT3(_name, _node_name, ...) \
    _name::_name( \
            VarNode* i0, VarNode* i1, VarNode* i2, const Param& param, \
            const OperatorNodeConfig& config) \
            : Super(OperatorNodeBaseCtorParam{i0->owner_graph(), config, \
                                              _node_name, {i0}}, \
                    ##__VA_ARGS__) { \
        init_megdnn_opr(*this, param); \
        add_input({i0, i1, i2}); \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this); \
    }

/*!
 * \brief define both the constructor (via MEGDNN_OPR_CTOR_INIT3) and the
 *      static `make()` factory of a three-input MegDNN operator wrapper
 */
#define MEGDNN_OPR_INIT3(_name, _node_name, ...) \
    MEGDNN_OPR_CTOR_INIT3(_name, _node_name, ##__VA_ARGS__) \
    SymbolVar _name::make( \
            SymbolVar i0, SymbolVar i1, SymbolVar i2, const Param& param, \
            const OperatorNodeConfig& config) { \
        /* the hook gets pointers to the inputs and may rewrite them, */ \
        /* so the nodes are read only after it has run */ \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0, &i1, &i2}); \
        return i0.insert_single_output_opr<_name>( \
                i0.node(), i1.node(), i2.node(), param, config); \
    }
/*!
 * \brief define the constructor of a four-input MegDNN operator wrapper
 *
 * Expands to `_name::_name(VarNode* i0, VarNode* i1, VarNode* i2,
 * VarNode* i3, const Param&, const OperatorNodeConfig&)`. Only i0 appears in
 * the OperatorNodeBaseCtorParam input list; all inputs are registered through
 * add_input(). Extra macro arguments are forwarded to the Super constructor.
 */
#define MEGDNN_OPR_CTOR_INIT4(_name, _node_name, ...) \
    _name::_name( \
            VarNode* i0, VarNode* i1, VarNode* i2, VarNode* i3, \
            const Param& param, const OperatorNodeConfig& config) \
            : Super(OperatorNodeBaseCtorParam{i0->owner_graph(), config, \
                                              _node_name, {i0}}, \
                    ##__VA_ARGS__) { \
        init_megdnn_opr(*this, param); \
        add_input({i0, i1, i2, i3}); \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this); \
    }

/*!
 * \brief define both the constructor (via MEGDNN_OPR_CTOR_INIT4) and the
 *      static `make()` factory of a four-input MegDNN operator wrapper
 */
#define MEGDNN_OPR_INIT4(_name, _node_name, ...) \
    MEGDNN_OPR_CTOR_INIT4(_name, _node_name, ##__VA_ARGS__) \
    SymbolVar _name::make( \
            SymbolVar i0, SymbolVar i1, SymbolVar i2, SymbolVar i3, \
            const Param& param, const OperatorNodeConfig& config) { \
        /* the hook gets pointers to the inputs and may rewrite them, */ \
        /* so the nodes are read only after it has run */ \
        intl::MegDNNOprInitInputsModifier<_name>::apply( \
                param, {&i0, &i1, &i2, &i3}); \
        return i0.insert_single_output_opr<_name>( \
                i0.node(), i1.node(), i2.node(), i3.node(), param, config); \
    }