#include "lite_build_config.h"
#if LITE_BUILD_WITH_MGE
#include "../src/decryption/decrypt_base.h"
#include "../src/network_impl_base.h"
#include "test_common.h"
#include "megbrain/opr/io.h"
#include "megbrain/tensor.h"
#include "megbrain/utils/metahelper.h"
#include <gtest/gtest.h>
#include <string.h>
#include <chrono>
#include <memory>
#include <random>
using namespace lite;
// Registering a new decryption method (with a default-constructed function
// and an empty key) must add exactly one entry to the global registry.
TEST(TestMisc, DecryptionRegister) {
    const size_t count_before = decryption_static_data().decryption_methods.size();
    // At least one method is expected to be registered before this test runs.
    ASSERT_GE(count_before, 1u);

    DecryptionFunc empty_func;
    register_decryption_and_key("AllForTest0", empty_func, {});

    const size_t count_after = decryption_static_data().decryption_methods.size();
    ASSERT_EQ(count_before + 1, count_after);
}
// Registers a method with an empty function and key, then checks that
// update_decryption_or_key can replace the function and the key independently.
TEST(TestMisc, DecryptionUpdate) {
    const char* const method_name = "AllForTest1";

    DecryptionFunc decrypt_fn;
    register_decryption_and_key(method_name, decrypt_fn, {});

    // Phase 1: install a trivial function that always yields an empty buffer,
    // passing an empty key alongside it.
    decrypt_fn = [](const void*, size_t,
                    const std::vector<uint8_t>&) -> std::vector<uint8_t> {
        return {};
    };
    update_decryption_or_key(method_name, decrypt_fn, {});
    ASSERT_NE(decryption_static_data().decryption_methods[method_name].first, nullptr);
    ASSERT_EQ(
            decryption_static_data().decryption_methods[method_name].second->size(),
            0);

    // Phase 2: pass an empty function with a 3-byte key; the stored key must
    // now have size 3.
    update_decryption_or_key(method_name, {}, {1, 2, 3});
    ASSERT_EQ(
            decryption_static_data().decryption_methods[method_name].second->size(),
            3);
}
// Loads the same model twice through a single GraphLoader (both loads use the
// same mgb_config, and therefore the same comp_graph) and checks that each
// ImmutableTensor operator of the second load has the same raw storage pointer
// as the corresponding operator of the first load.
TEST(TestMisc, SharedSameDeviceTensor) {
    using namespace mgb;
    serialization::GraphLoader::LoadConfig mgb_config;
    // Map every comp node referenced by the model onto the CPU.
    mgb_config.comp_node_mapper = [](CompNode::Locator& loc) {
        loc = to_compnode_locator(LiteDeviceType::LITE_CPU);
    };
    mgb_config.comp_graph = ComputingGraph::make();
    std::string model_path = "./shufflenet.mge";
    auto inp_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
    auto format = serialization::GraphLoader::identify_graph_dump_format(*inp_file);
    mgb_assert(
            format.valid(),
            "invalid model: unknown model format, please make sure input "
            "file is generated by GraphDumper");
    auto loader = serialization::GraphLoader::make(std::move(inp_file), format.val());
    // NOTE(review): the second argument `true` presumably rewinds the input
    // file so the model can be loaded again — confirm against GraphLoader::load.
    auto load_ret_1 = loader->load(mgb_config, true);
    auto load_ret_2 = loader->load(mgb_config, true);
    ASSERT_EQ(load_ret_1.output_var_list.size(), load_ret_2.output_var_list.size());

    ComputingGraph::OutputSpec out_spec_1, out_spec_2;
    for (size_t i = 0; i < load_ret_1.output_var_list.size(); i++) {
        out_spec_1.emplace_back(load_ret_1.output_var_list[i], nullptr);
        out_spec_2.emplace_back(load_ret_2.output_var_list[i], nullptr);
    }
    // Each load is compiled with its OWN outputs. Compiling load 2 with load 1's
    // outputs, and then inspecting func_1 twice, would compare the first load
    // against itself and make the assertions below vacuous.
    auto func_1 = load_ret_1.graph_compile(out_spec_1);
    auto func_2 = load_ret_2.graph_compile(out_spec_2);

    // Collects the ImmutableTensor operators of a compiled function, in the
    // order iter_opr_seq visits them.
    auto collect_immutable_tensors = [](auto& func) {
        std::vector<opr::ImmutableTensor*> imm_oprs;
        func->iter_opr_seq([&imm_oprs](cg::OperatorNodeBase* node) -> bool {
            if (auto* imm = node->try_cast_final<opr::ImmutableTensor>()) {
                imm_oprs.push_back(imm);
            }
            return true;  // continue iterating over the whole sequence
        });
        return imm_oprs;
    };

    auto oprs_1 = collect_immutable_tensors(func_1);
    auto oprs_2 = collect_immutable_tensors(func_2);
    ASSERT_EQ(oprs_1.size(), oprs_2.size());
    for (size_t i = 0; i < oprs_1.size(); i++) {
        auto tensor_1 = oprs_1[i]->value();
        auto tensor_2 = oprs_2[i]->value();
        ASSERT_EQ(tensor_1.raw_ptr(), tensor_2.raw_ptr());
    }
}
#endif