megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file src/opr/impl/internal/megdnn_opr_wrapper.inl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

namespace mgb {
namespace opr {

namespace intl {
/*!
 * \brief template that can be specialized so inputs of an operator could be
 *      modified in-place
 *
 * Invoked by MEGDNN_OPR_INIT* macros
 *
 * \tparam Opr an megbrain opr final class
 */
template <class Opr>
struct MegDNNOprInitInputsModifier {
    static inline void apply(
            const typename Opr::Param& param,
            std::initializer_list<SymbolVar*> inputs) {
        MGB_MARK_USED_VAR(param);
        MGB_MARK_USED_VAR(inputs);
    }
};

/*!
 * \brief template that can be specialized to be called in opr constructor
 *
 * Invoked by MEGDNN_OPR_INIT* macros
 */
template <class Opr>
struct MegDNNOprInitPostCtor {
    static inline void apply(cg::OperatorNodeBase& opr) { MGB_MARK_USED_VAR(opr); }
};

//! get megdnn Workspace object from a workspace var
megdnn::Workspace get_megdnn_workspace_from_var(VarNode* var);

/*!
 * \brief A UserData object associated with the computing graph to get
 *      maximal usable workspace.
 *
 * It works by first limit workspace to 0 and alloc to get free memory, and
 * assume workspace can use all free memory.
 * It would produce a var node, which should be taken as a value dep for
 * workspace static infer functors so memory manager can re-allocate.
 */
class WorkspaceLimitGetter {
    class Impl;
    static Impl* get_impl(ComputingGraph* graph);

public:
    /*!
     * \brief get usable workspace size in bytes for a comp node
     *
     * Can only be called after is_prealloc_run() returns false
     *
     * \param old_limit workspace limit set by user, which would be an
     *      upper bound for the return value
     */
    static size_t get_workspace_limit(
            ComputingGraph* graph, CompNode cn, size_t old_limit);

    //! return whether current is pre-allocation so workspace should
    //! return 0
    static bool is_prealloc_run(ComputingGraph* graph);

    /*!
     * \brief register WorkspaceLimitGetter in a graph
     * \return an var to be added as extra value dep for workspace
     *      infer; it would be null if WorkspaceLimitGetter is disabled
     *      at compile time
     */
    static VarNode* register_to_graph(ComputingGraph* graph);
};

/*!
 * a template that can be specialized to indicate whether
 * WorkspaceLimitGetter is needed for an operator class
 *
 * \tparam MegDNNOpr a megdnn opr class
 */
template <class MegDNNOpr>
struct AutoAddWorkspaceNeedLimitGetter {
    static constexpr bool val = false;
};

/*!
 * \brief implement megdnn::DynOutMallocPolicy using memory management
 *      system in megbrain
 */
class MegDNNDynOutMallocImpl final : public megdnn::DynOutMallocPolicy {
    cg::OperatorNodeBase* m_opr;
    CompNode m_cn;

public:
    MegDNNDynOutMallocImpl(cg::OperatorNodeBase* opr, CompNode cn)
            : m_opr{opr}, m_cn{cn} {}

    megdnn::TensorND alloc_output(
            size_t id, DType dtype, const TensorShape& shape, void* user_data) override;

    void* alloc_workspace(size_t sz, void* user_data) override;
    void free_workspace(void* ptr, void* user_data) override;
};

/* ======================= MegDNNOprMethInvoker ======================= */
namespace {

template <int nr_in, int nr_out>
struct _MegDNNOprMethInvoker;

template <class Opr>
using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPUTS>;

#define _NR_INPUTS          1
#define _NR_OUTPUTS         1
#define _FOREACH_IO(_i, _o) _i(0), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          1
#define _NR_OUTPUTS         2
#define _FOREACH_IO(_i, _o) _i(0), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          1
#define _NR_OUTPUTS         3
#define _FOREACH_IO(_i, _o) _i(0), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          2
#define _NR_OUTPUTS         1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          2
#define _NR_OUTPUTS         2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          3
#define _NR_OUTPUTS         1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          3
#define _NR_OUTPUTS         2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          3
#define _NR_OUTPUTS         3
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          4
#define _NR_OUTPUTS         1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          4
#define _NR_OUTPUTS         4
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0), _o(1), _o(2), _o(3)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          5
#define _NR_OUTPUTS         1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          5
#define _NR_OUTPUTS         2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          5
#define _NR_OUTPUTS         3
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          6
#define _NR_OUTPUTS         1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS          6
#define _NR_OUTPUTS         2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS  6
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) \
    _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS  7
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) \
    _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS  9
#define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o)                                                      \
    _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _i(7), _i(8), _o(0), _o(1), \
            _o(2), _o(3)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
}  // anonymous namespace

/* ======================= MegDNNOprWrapperFwd ======================= */

template <class MegDNNOpr>
void MegDNNOprWrapperFwd<MegDNNOpr>::init_output_static_infer_desc() {
    Super::set_nr_managed_outputs(this->output().size() - 1);
    Super::init_output_static_infer_desc();
    this->init_output_static_infer_desc_workspace(
            AutoAddWorkspaceNeedLimitGetter<MegDNNOpr>::val);
}

template <class MegDNNOpr>
void MegDNNOprWrapperFwd<MegDNNOpr>::scn_do_execute() {
    MegDNNOprMethInvoker<MegDNNOpr>::exec(this->megdnn_opr(), this);
}

template <class MegDNNOpr>
size_t MegDNNOprWrapperFwd<MegDNNOpr>::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return this->mixin_get_workspace_size_bytes_by_megdnn(
            *this, input_shapes, output_shapes);
}

template <class MegDNNOpr>
void MegDNNOprWrapperFwd<MegDNNOpr>::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    MegDNNOprMethInvoker<MegDNNOpr>::deduce_layout(
            this->megdnn_opr(), this, inp_shape, out_shape);
}

/* ======================= MegDNNOprWrapperBwd ======================= */

template <class MegDNNOpr>
void MegDNNOprWrapperBwd<MegDNNOpr>::init_output_static_infer_desc() {
    this->mixin_init_output_static_infer_desc_bwd(*this);
    this->init_output_static_infer_desc_workspace(
            AutoAddWorkspaceNeedLimitGetter<MegDNNOpr>::val);
}

template <class MegDNNOpr>
void MegDNNOprWrapperBwd<MegDNNOpr>::scn_do_execute() {
    MegDNNOprMethInvoker<MegDNNOpr>::exec(this->megdnn_opr(), this);
}

template <class MegDNNOpr>
size_t MegDNNOprWrapperBwd<MegDNNOpr>::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return this->mixin_get_workspace_size_bytes_by_megdnn(
            *this, input_shapes, output_shapes);
}

template <class MegDNNOpr>
typename MegDNNOprWrapperBwd<MegDNNOpr>::Super::NodeProp* MegDNNOprWrapperBwd<
        MegDNNOpr>::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    this->mixin_update_node_prop(*this, prop);
    return prop;
}

}  // namespace intl

namespace mixin {
/* ======================= MegDNNOprHolderImpl ======================= */

template <class MegDNNOpr, bool add_workspace, class OprHolder>
size_t MegDNNOprHolderImpl<MegDNNOpr, add_workspace, OprHolder>::
        mixin_get_workspace_size_bytes_by_megdnn(
                const OperatorNodeBase& opr, const TensorShapeArray& input_shapes,
                const TensorShapeArray& output_shapes) const {
    static_assert(add_workspace, "must add_workspace");
    return intl::MegDNNOprMethInvoker<MegDNNOpr>::get_workspace_in_bytes(
            this->megdnn_opr(), &opr, input_shapes, output_shapes);
}
}  // namespace mixin

}  // namespace opr
}  // namespace mgb

//! generate opr constructor, with 1 arg
#define MEGDNN_OPR_CTOR_INIT1(_name, _node_name, ...)                               \
    _name::_name(VarNode* i0, const Param& param, const OperatorNodeConfig& config) \
            : Super(                                                                \
                      OperatorNodeBaseCtorParam{                                    \
                              i0->owner_graph(), config, _node_name, {i0}},         \
                      ##__VA_ARGS__) {                                              \
        init_megdnn_opr(*this, param);                                              \
        add_input({i0});                                                            \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this);                           \
    }
//! generate opr constructor and ::make, with 1 arg
#define MEGDNN_OPR_INIT1(_name, _node_name, ...)                                  \
    MEGDNN_OPR_CTOR_INIT1(_name, _node_name, ##__VA_ARGS__)                       \
    SymbolVar _name::make(                                                        \
            SymbolVar i0, const Param& param, const OperatorNodeConfig& config) { \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0});            \
        return i0.insert_single_output_opr<_name>(i0.node(), param, config);      \
    }

//! generate opr constructor, with 2 args
#define MEGDNN_OPR_CTOR_INIT2(_name, _node_name, ...)                       \
    _name::_name(                                                           \
            VarNode* i0, VarNode* i1, const Param& param,                   \
            const OperatorNodeConfig& config)                               \
            : Super(                                                        \
                      OperatorNodeBaseCtorParam{                            \
                              i0->owner_graph(), config, _node_name, {i0}}, \
                      ##__VA_ARGS__) {                                      \
        init_megdnn_opr(*this, param);                                      \
        add_input({i0, i1});                                                \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this);                   \
    }
//! generate opr constructor and ::make, with 2 args
#define MEGDNN_OPR_INIT2(_name, _node_name, ...)                            \
    MEGDNN_OPR_CTOR_INIT2(_name, _node_name, ##__VA_ARGS__)                 \
    SymbolVar _name::make(                                                  \
            SymbolVar i0, SymbolVar i1, const Param& param,                 \
            const OperatorNodeConfig& config) {                             \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0, &i1}); \
        return i0.insert_single_output_opr<_name>(                          \
                i0.node(), i1.node(), param, config);                       \
    }

//! generate opr constructor, with 3 args
#define MEGDNN_OPR_CTOR_INIT3(_name, _node_name, ...)                       \
    _name::_name(                                                           \
            VarNode* i0, VarNode* i1, VarNode* i2, const Param& param,      \
            const OperatorNodeConfig& config)                               \
            : Super(                                                        \
                      OperatorNodeBaseCtorParam{                            \
                              i0->owner_graph(), config, _node_name, {i0}}, \
                      ##__VA_ARGS__) {                                      \
        init_megdnn_opr(*this, param);                                      \
        add_input({i0, i1, i2});                                            \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this);                   \
    }
//! generate opr constructor and ::make, with 3 args
#define MEGDNN_OPR_INIT3(_name, _node_name, ...)                                 \
    MEGDNN_OPR_CTOR_INIT3(_name, _node_name, ##__VA_ARGS__)                      \
    SymbolVar _name::make(                                                       \
            SymbolVar i0, SymbolVar i1, SymbolVar i2, const Param& param,        \
            const OperatorNodeConfig& config) {                                  \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0, &i1, &i2}); \
        return i0.insert_single_output_opr<_name>(                               \
                i0.node(), i1.node(), i2.node(), param, config);                 \
    }

//! generate opr constructor, with 4 args
#define MEGDNN_OPR_CTOR_INIT4(_name, _node_name, ...)                               \
    _name::_name(                                                                   \
            VarNode* i0, VarNode* i1, VarNode* i2, VarNode* i3, const Param& param, \
            const OperatorNodeConfig& config)                                       \
            : Super(                                                                \
                      OperatorNodeBaseCtorParam{                                    \
                              i0->owner_graph(), config, _node_name, {i0}},         \
                      ##__VA_ARGS__) {                                              \
        init_megdnn_opr(*this, param);                                              \
        add_input({i0, i1, i2, i3});                                                \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this);                           \
    }
//! generate opr constructor and ::make, with 4 args
#define MEGDNN_OPR_INIT4(_name, _node_name, ...)                                      \
    MEGDNN_OPR_CTOR_INIT4(_name, _node_name, ##__VA_ARGS__)                           \
    SymbolVar _name::make(                                                            \
            SymbolVar i0, SymbolVar i1, SymbolVar i2, SymbolVar i3,                   \
            const Param& param, const OperatorNodeConfig& config) {                   \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0, &i1, &i2, &i3}); \
        return i0.insert_single_output_opr<_name>(                                    \
                i0.node(), i1.node(), i2.node(), i3.node(), param, config);           \
    }

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}