megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file src/serialization/impl/serializer.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/serialization/serializer.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/utility.h"

namespace mgb {
namespace serialization {

/* ====================== helper impls ====================== */
//! Destructor of LoadResult: defaulted, but defined out of line in this .cpp
//! and declared noexcept. NOTE(review): presumably kept out of line so the
//! header can hold members whose types are incomplete there — confirm.
GraphLoader::LoadResult::~LoadResult() noexcept = default;

/*!
 * \brief compile the loaded graph for \p outspec and hand back the executable
 *
 * When the graph option comp_node_seq_record_level equals 2, the graph is
 * passed to ComputingGraph::assert_destroy immediately after compiling.
 * NOTE(review): callers should therefore presumably treat `graph` as no
 * longer usable in that mode — confirm against assert_destroy's contract.
 */
std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile(
        const ComputingGraph::OutputSpec& outspec) {
    auto executable = graph->compile(outspec);
    const bool record_level_two =
            graph->options().comp_node_seq_record_level == 2;
    if (record_level_two) {
        ComputingGraph::assert_destroy(graph);
    }
    return executable;
}

/*!
 * \brief run a basic optimize_for_inference ahead of compilation and rebind
 *      the output vars stored in this LoadResult to the optimized ones
 *
 * Only acts when force_output_use_user_specified_memory is enabled on the
 * graph. In that case:
 *  - output_var_list is replaced by the optimized vars;
 *  - output_var_map is rebuilt, keyed by each optimized var's name;
 *  - output_var_map_id keeps its ids, but every id now refers to the
 *    optimized var whose name matches the var previously stored under it.
 * Asserts if an optimized var's name matches no previously recorded output.
 */
void GraphLoader::LoadResult::graph_compile_ahead() {
    //! when force_output_use_user_specified_memory is set, the output var may
    //! be changed by gopt, then the var in LoadResult can not exist, so here
    //! just do basic optimize_for_inference ahead, and replace the var in
    //! LoadResult
    if (graph->options().force_output_use_user_specified_memory) {
        auto options = gopt::OptimizeForInferenceOptions{};
        auto new_vars = gopt::optimize_for_inference(output_var_list, options);
        output_var_list = new_vars;
        output_var_map.clear();
        for (auto& var : new_vars) {
            output_var_map[var.node()->cname()] = var;
        }
        //! rebuild the id -> var mapping: ids are kept, vars are matched by
        //! name; decltype keeps the container type identical to the member
        decltype(output_var_map_id) var_map_id;
        for (auto& var : new_vars) {
            bool found = false;
            for (auto& old_var_it : output_var_map_id) {
                if (old_var_it.second.node()->name() == var.node()->name()) {
                    found = true;
                    var_map_id[old_var_it.first] = var;
                }
            }
            mgb_assert(
                    found, "can't find var name %s when optimize_for_inference. ",
                    var.node()->cname());
        }
        //! commit the rebuilt mapping; without this, id-based lookups would
        //! still return the pre-optimization vars, which may no longer exist
        output_var_map_id = std::move(var_map_id);
    }
}

/*!
 * \brief build a name -> tensor-pointer view over shared_tensor_id_map()
 *
 * Every entry must carry a non-empty name (the dump must not have stripped
 * it), and names must be unique; both conditions are asserted. The returned
 * map stores pointers into the containers yielded by shared_tensor_id_map().
 */
GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
    SharedTensorNameMap name_map;
    for (auto&& entry : shared_tensor_id_map()) {
        auto&& tensor_name = entry.first;
        mgb_assert(!tensor_name.empty(), "name stripped during graph dump");
        const bool inserted = name_map.emplace(tensor_name, &entry.second).second;
        mgb_assert(inserted);
    }
    return name_map;
}
//! FlatBuffers-format loader/dumper factories and file-format probe. They are
//! only declared here; their definitions live in another translation unit
//! (presumably the fbs serializer implementation — confirm). Their uses in
//! this file are guarded by MGB_ENABLE_FBS_SERIALIZATION.
std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file);
std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file);
bool is_fbs_file(InputFile& file);

/*!
 * \brief whether \p opr should be dropped when dumping a graph
 *
 * With gradient support compiled in (MGB_ENABLE_GRAD), SetGrad operators are
 * dropped; otherwise nothing is.
 */
bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) {
    bool remove = false;
#if MGB_ENABLE_GRAD
    remove = opr->same_type<opr::SetGrad>();
#endif
    return remove;
}

std::unique_ptr<GraphDumper> GraphDumper::make(
        std::unique_ptr<OutputFile> file, GraphDumpFormat format) {
    switch (format) {
        case GraphDumpFormat::FLATBUFFERS:
#if MGB_ENABLE_FBS_SERIALIZATION
            return make_fbs_dumper(std::move(file));
#endif
            MGB_FALLTHRU
        default:
            mgb_throw(SerializationError, "unsupported serialization format requested");
    }
    mgb_assert(false, "unreachable");
}

std::unique_ptr<GraphLoader> GraphLoader::make(
        std::unique_ptr<InputFile> file, GraphDumpFormat format) {
    switch (format) {
        case GraphDumpFormat::FLATBUFFERS:
#if MGB_ENABLE_FBS_SERIALIZATION
            return make_fbs_loader(std::move(file));
#endif
            MGB_FALLTHRU
        default:
            mgb_throw(SerializationError, "unsupported serialization format requested");
    }
    mgb_assert(false, "unreachable");
}

/*!
 * \brief detect the dump format of \p file
 *
 * Returns GraphDumpFormat::FLATBUFFERS if FlatBuffers support is compiled in
 * and is_fbs_file() recognizes the file; otherwise returns an empty Maybe.
 */
Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file) {
#if MGB_ENABLE_FBS_SERIALIZATION
    if (!is_fbs_file(file)) {
        return {};
    }
    return GraphDumpFormat::FLATBUFFERS;
#else
    return {};
#endif
}

}  // namespace serialization
}  // namespace mgb