megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file imperative/src/impl/ops/opr_attr.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/serialization/opr_load_dump.h"

#include "../op_trait.h"
#include "megbrain/imperative/proxy_graph_detail.h"

namespace mgb {
namespace imperative {

namespace {
/*!
 * \brief load context that serves raw reads from the bytes of an OprAttr::Param
 *
 * Only read_raw() and graph() are functional; load_tensor(),
 * load_tensor_shared() and config() assert unconditionally.
 */
class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD {
    //! parameter bytes being read; held by reference, so it must outlive this
    //! context
    const OprAttr::Param& m_param;
    //! number of bytes of m_param consumed so far
    size_t m_pos = 0;
    //! graph handed back by graph()
    ComputingGraph* m_graph;

    /*!
     * \brief copy the next \p size bytes of the parameter buffer into \p dest
     *      and advance the read position
     */
    void read_raw(void* dest, size_t size) override final {
        // m_pos never exceeds m_param.size() (it starts at 0 and only advances
        // after this check), so the subtraction cannot wrap; comparing against
        // the remaining byte count also rejects a size for which m_pos + size
        // would overflow size_t
        mgb_assert(size <= m_param.size() - m_pos, "too many bytes requested");
        memcpy(dest, m_param.data() + m_pos, size);
        m_pos += size;
    }

    //! not supported by this context: always asserts
    std::shared_ptr<HostTensorND> load_tensor() override { mgb_assert(0); }

    //! not supported by this context: always asserts
    std::shared_ptr<DeviceTensorND> load_tensor_shared() override { mgb_assert(0); }

    //! not supported by this context: always asserts
    const serialization::GraphLoadConfig& config() const override { mgb_assert(0); }

public:
    /*!
     * \param param parameter bytes to read from
     * \param graph graph returned by graph(); dereferenced there, so it is
     *      expected to be non-null
     */
    OprParamsLoadContext(const OprAttr::Param& param, ComputingGraph* graph)
            : serialization::OprLoadContextRawPOD(false),
              m_param(param),
              m_graph(graph) {}

    //! checks that every byte of the parameter buffer has been read
    // NOTE(review): if mgb_assert signals failure by throwing, a failure here
    // would escape a destructor -- confirm this is the intended behaviour
    ~OprParamsLoadContext() {
        mgb_assert(m_pos == m_param.size(), "param not fully consumed");
    }

    ComputingGraph& graph() override { return *m_graph; }
};

/*!
 * \brief dump context that accumulates raw writes into an in-memory
 *      OprAttr::Param
 *
 * dump_tensor() and config() are not supported and assert unconditionally.
 */
class OprParamsDumpContext final : public serialization::OprDumpContextRawPOD {
public:
    //! bytes collected by write_raw(), in the order they were written
    OprAttr::Param m_param;

    OprParamsDumpContext() : serialization::OprDumpContextRawPOD(false) {}

    //! append \p size bytes starting at \p data to the end of m_param
    void write_raw(const void* data, size_t size) {
        const char* const first = static_cast<const char*>(data);
        const char* const last = first + size;
        m_param.insert(m_param.end(), first, last);
    }

    //! not supported by this context: always asserts
    void dump_tensor(
            const std::string& name, const HostTensorND& tensor,
            TensorWriteMethod method) {
        mgb_assert(0);
    }

    //! not supported by this context: always asserts
    const serialization::GraphDumpConfig& config() const { mgb_assert(0); }
};

/*!
 * \brief build the operator described by an OprAttr on top of \p inputs
 *
 * The operator's registry entry is looked up by the attr's type name, and its
 * loader is run against the attr's parameter bytes.
 *
 * \param def op definition; must be an OprAttr
 * \param inputs input vars; must be non-empty, the owner graph of the first
 *      one is used as the graph of the load context
 * \return the usable outputs reported by the loader
 */
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    const OprAttr& opr_attr = def.cast_final_safe<OprAttr>();

    // work on a copy of the stored config so the name can be set on it
    auto node_config = opr_attr.config;
    node_config.name(opr_attr.make_name());

    mgb_assert(!inputs.empty());

    auto entry = serialization::OprRegistry::find_by_name(opr_attr.type);
    mgb_assert(entry, "operator %s not found", opr_attr.type.c_str());

    OprParamsLoadContext load_ctx{opr_attr.param, inputs[0]->owner_graph()};
    return entry->loader(load_ctx, inputs, node_config).usable_output();
}

/*!
 * \brief create an OprAttr op definition from an existing operator node
 *
 * The operator's registered dumper writes its parameters into an in-memory
 * buffer, which becomes the param of the resulting OprAttr.
 *
 * \param opr operator node to convert; its type must be registered and must
 *      have a dumper
 */
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
    OprParamsDumpContext dump_ctx;

    const auto opr_type = opr->dyn_typeinfo();
    auto entry = serialization::OprRegistry::find_by_type(opr_type);
    mgb_assert(entry, "operator %s not found", opr_type->name);
    mgb_assert(entry->dumper, "operator %s cannot be serialized", opr_type->name);

    entry->dumper(dump_ctx, *opr);

    // the collected bytes are moved out of the context rather than copied
    return OprAttr::make(entry->name, std::move(dump_ctx.m_param), opr->config());
}

//! OprAttr exposes no properties: the result is always an empty list and the
//! op definition is not inspected
std::vector<std::pair<const char*, std::string>> props(const OpDef& /* def */) {
    return {};
}

//! name of an OprAttr op: a copy of its type string (\p def must be an OprAttr)
std::string make_name(const OpDef& def) {
    return def.cast_final_safe<OprAttr>().type;
}

// Register the op trait for OprAttr, wiring in the four file-local
// implementations defined above.
// NOTE(review): fallback() is presumably what supplies defaults for the hooks
// not set here -- confirm against op_trait.h.
OP_TRAIT_REG(OprAttr, OprAttr)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .props(props)
        .make_name(make_name)
        .fallback();

}  // anonymous namespace

/*!
 * \brief equality against another OprAttr
 *
 * Two attrs are the same when their type, param bytes, and the comp_node and
 * output_dtype of their configs all compare equal; no other part of the config
 * is compared here.
 *
 * \param rhs_ object to compare with; cast to OprAttr without a runtime check
 */
bool OprAttr::is_same_st(const Hashable& rhs_) const {
    const auto& other = static_cast<const OprAttr&>(rhs_);

    if (!(type == other.type)) {
        return false;
    }
    if (!(param == other.param)) {
        return false;
    }
    if (!(config.comp_node() == other.config.comp_node())) {
        return false;
    }
    return config.output_dtype() == other.config.output_dtype();
}

/*!
 * \brief hash of this attr, combining the hashes of type, the param bytes and
 *      the config
 */
size_t OprAttr::hash() const {
    // hash param through a reference to its std::vector<char> form; casting by
    // value would copy the whole parameter buffer on every call just to hash it
    return hash_pair_combine(
            hash_pair_combine(
                    mgb::hash(type),
                    mgb::hash(static_cast<const std::vector<char>&>(param))),
            config.hash());
}

// Macro-generated out-of-class definitions for OprAttr.
// NOTE(review): presumably the dynamic type info support for the class -- see
// the macro definition to confirm.
MGB_DYN_TYPE_OBJ_FINAL_IMPL(OprAttr);

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}