// megenginelite-sys 1.8.2
//
// A safe megenginelite wrapper in Rust
// Documentation
/**
 * \file src/custom/test/param.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain_build_config.h"

#if MGB_CUSTOM_OP

#include <iostream>
#include "gtest/gtest.h"
#include "megbrain/custom/param.h"

#define PARAM_TEST_LOG 0

namespace custom {

// SchemaDef: declares eight ParamSchema locals in the enclosing scope, one per
// value kind exercised by these tests (bool, float, int, string, and a list of
// each). Each constructor call is given a parameter name, a default value and
// what appears to be a human-readable description string.
//
// Injected names: schema_bool, schema_flt, schema_int, schema_str,
// schema_bool_list, schema_flt_list, schema_int_list, schema_str_list.
// The expansion has no trailing semicolon, so it is written as `SchemaDef;`.
// No comments inside the body: a `//` comment before a trailing backslash
// would swallow the spliced next line.
#define SchemaDef                                                               \
    ParamSchema schema_bool("param_bool", true, "bool");                        \
    ParamSchema schema_flt("param_flt", 2.3f, "float");                         \
    ParamSchema schema_int("param_int", 4, "int");                              \
    ParamSchema schema_str("param_str", "test", "string");                      \
    ParamSchema schema_bool_list("param_bl", {true, false, true}, "bool list"); \
    ParamSchema schema_flt_list("param_fl", {1.1f, 2.2f, 3.3f}, "float list");  \
    ParamSchema schema_int_list("param_il", {1, 2, 3}, "int list");             \
    ParamSchema schema_str_list("param_sl", {"test1", "test2", "test3"}, "string list")

// InfoDef: appends (copies of) the eight schemas declared by SchemaDef to
// `info.meta()`, in declaration order.
// Requires a ParamInfo named `info` and a preceding SchemaDef in the same
// scope. Like SchemaDef, it has no trailing semicolon: use it as `InfoDef;`.
#define InfoDef                                 \
    info.meta().emplace_back(schema_bool);      \
    info.meta().emplace_back(schema_flt);       \
    info.meta().emplace_back(schema_int);       \
    info.meta().emplace_back(schema_str);       \
    info.meta().emplace_back(schema_bool_list); \
    info.meta().emplace_back(schema_flt_list);  \
    info.meta().emplace_back(schema_int_list);  \
    info.meta().emplace_back(schema_str_list)

// Checks that every ParamSchema keeps the default value it was constructed
// with, and that copying a schema preserves that default.
// Previously the whole body was log-only (PARAM_TEST_LOG == 0), so the test
// checked nothing in normal builds.
TEST(TestParam, TestParamScheme) {
    SchemaDef;
    ParamSchema new_schema = schema_int;

    // Same defaults that TestParam expects a freshly built Param to expose.
    ASSERT_TRUE(schema_bool.default_val() == true);
    ASSERT_TRUE(schema_flt.default_val() == 2.3f);
    ASSERT_TRUE(schema_int.default_val() == 4);
    ASSERT_TRUE(schema_str.default_val() == "test");
    ASSERT_TRUE(schema_bool_list.default_val() == ParamVal({true, false, true}));
    ASSERT_TRUE(schema_flt_list.default_val() == ParamVal({1.1f, 2.2f, 3.3f}));
    ASSERT_TRUE(schema_int_list.default_val() == ParamVal({1, 2, 3}));
    ASSERT_TRUE(
            schema_str_list.default_val() ==
            ParamVal({"test1", "test2", "test3"}));

    // A copied schema carries the same default value as its source.
    ASSERT_TRUE(new_schema.default_val() == schema_int.default_val());

#if PARAM_TEST_LOG
    std::cout << schema_bool.str() << std::endl;
    std::cout << schema_flt.str() << std::endl;
    std::cout << schema_int.str() << std::endl;
    std::cout << schema_str.str() << std::endl;
    std::cout << schema_bool_list.str()
              << "len: " << schema_bool_list.default_val().size() << std::endl;
    std::cout << schema_flt_list.str()
              << "len: " << schema_flt_list.default_val().size() << std::endl;
    std::cout << schema_int_list.str()
              << "len: " << schema_int_list.default_val().size() << std::endl;
    std::cout << schema_str_list.str()
              << "len: " << schema_str_list.default_val().size() << std::endl;

    std::cout << new_schema.str() << std::endl;
#endif
}

// Exercises ParamVal: construction from basic values and lists, dynamic type
// tracking on reassignment, equality/ordering, and the `+` operator.
TEST(TestParam, TestParamVal) {
    // One reference value per dynamic type exercised below.
    ParamVal ref_flt = 1.2f;
    ParamVal ref_bool = true;
    ParamVal ref_str = "test";
    ParamVal ref_int_list = {0, 1, 2};
    ParamVal ref_bool_list = {true, false, true};

#if PARAM_TEST_LOG
    ParamVal ref_str_list = {"test1", "test2", "test3"};
    std::cout << ref_flt.str() << std::endl;
    std::cout << ref_bool.str() << std::endl;
    std::cout << ref_str.str() << std::endl;
    std::cout << ref_int_list.str() << std::endl;
    std::cout << ref_bool_list.str() << std::endl;
    std::cout << ref_str_list.str() << std::endl;
#endif

    // `val` is reassigned over and over; each check reads the value assigned
    // most recently, so the statements below must stay in this order.

    // Floating point: copy, retype on assignment, ordering, arithmetic.
    ParamVal val = ref_flt;
    ASSERT_TRUE(val.type() == ref_flt.type());
    ASSERT_TRUE(val == ref_flt);
    val = 1.3;  // a double literal changes the stored type
    ASSERT_TRUE(val.type() != ref_flt.type());
    ASSERT_TRUE(val != ref_flt);
    ASSERT_TRUE(val > ref_flt);
    val = val + ref_flt;
    ASSERT_TRUE(val.type() == ParamDynType::Float64);
    ASSERT_TRUE(val == 1.3 + 1.2f);
    val = 1.3f + 1.2f;
    ASSERT_TRUE(val.type() == ref_flt.type());

    // Bool.
    val = false;
    ASSERT_TRUE(val.type() == ref_bool.type());
    ASSERT_TRUE(val.type() == ParamDynType::Bool);
    ASSERT_TRUE(val != ref_bool);

    // String: equality, ordering and concatenation.
    val = "test";
    ASSERT_TRUE(val.type() == ref_str.type());
    ASSERT_TRUE(val.type() == ParamDynType::String);
    ASSERT_TRUE(val == ref_str);
    val = "test1";
    ASSERT_TRUE(val > ref_str);
    val = val + ref_str;
    ASSERT_TRUE(val == "test1test");

    // Int32 list: equality and ordering.
    val = {0, 1, 2};
    ASSERT_TRUE(val.type() == ref_int_list.type());
    ASSERT_TRUE(val.type() == ParamDynType::Int32List);
    ASSERT_TRUE(val == ref_int_list);
    val = {3, 2, 1};
    ASSERT_TRUE(val != ref_int_list);
    ASSERT_TRUE(val > ref_int_list);

    // Bool list: equality.
    val = {true, false, true};
    ASSERT_TRUE(val.type() == ref_bool_list.type());
    ASSERT_TRUE(val.type() == ParamDynType::BoolList);
    ASSERT_TRUE(val == ref_bool_list);
    val = {false, true, false};
    ASSERT_TRUE(val != ref_bool_list);
}

// Builds a ParamInfo from the SchemaDef schemas, then copies its schema list
// into two fresh ParamInfo objects (via set_meta() and via assignment through
// meta()) and checks that both copies hold every schema.
TEST(TestParam, TestParamInfo) {
    ParamInfo info;
    info.set_tag("Test");
#if PARAM_TEST_LOG
    uint32_t tag = info.tag();
    std::cout << tag << std::endl;
#endif

    SchemaDef;
    InfoDef;
    // InfoDef registers all eight schemas declared by SchemaDef.
    ASSERT_EQ(info.meta().size(), 8u);

    ParamInfo new_info1;
    ParamInfo new_info2;
    new_info1.set_meta(info.meta());
    new_info2.meta() = info.meta();

    // Both copy paths must transfer the whole schema list.
    ASSERT_EQ(new_info1.meta().size(), info.meta().size());
    ASSERT_EQ(new_info2.meta().size(), info.meta().size());

#if PARAM_TEST_LOG
    for (const auto& ele : new_info1.meta()) {
        std::cout << ele.str() << std::endl;
    }
    for (const auto& ele : new_info2.meta()) {
        std::cout << ele.str() << std::endl;
    }
#endif
}

// End-to-end test of Param: construction from a ParamInfo (values start at
// the schema defaults), mutation through operator[], and a
// to_bytes()/from_bytes() round trip.
TEST(TestParam, TestParam) {
    ParamInfo info;
    SchemaDef;
    InfoDef;

    Param param(info);

#if PARAM_TEST_LOG
    std::vector<std::string> names = {"param_bool", "param_flt", "param_int",
                                      "param_str",  "param_bl",  "param_fl",
                                      "param_il",   "param_sl"};
    for (const auto& name : names) {
        std::cout << param[name].str() << std::endl;
    }
#endif
    // A freshly built Param exposes the schema defaults.
    ASSERT_TRUE(param["param_bool"] == true);
    ASSERT_TRUE(param["param_flt"] == 2.3f);
    ASSERT_TRUE(param["param_int"] == 4);
    ASSERT_TRUE(param["param_str"] == "test");
    ASSERT_TRUE(param["param_bl"] == ParamVal({true, false, true}));
    ASSERT_TRUE(param["param_fl"] == ParamVal({1.1f, 2.2f, 3.3f}));
    ASSERT_TRUE(param["param_il"] == ParamVal({1, 2, 3}));
    ASSERT_TRUE(param["param_sl"] == ParamVal({"test1", "test2", "test3"}));

    // Overwrite every entry; the list entries also change length.
    param["param_bool"] = false;
    param["param_flt"] = 3.4f;
    param["param_int"] = 5;
    param["param_str"] = "tset";
    param["param_bl"] = {false, true, false, true};
    param["param_fl"] = {7.6f, 6.5f};
    param["param_il"] = {5, 4, 3, 2, 1};
    param["param_sl"] = {"1tset", "2tset", "3tset", "4tset", "5tset"};

    ASSERT_TRUE(param["param_bool"] != true);
    ASSERT_TRUE(param["param_flt"] != 2.3f);
    ASSERT_TRUE(param["param_int"] != 4);
    ASSERT_TRUE(param["param_str"] != "test");
    ASSERT_TRUE(param["param_bl"] != ParamVal({true, false, true}));
    ASSERT_TRUE(param["param_fl"] != ParamVal({1.1f, 2.2f, 3.3f}));
    ASSERT_TRUE(param["param_il"] != ParamVal({1, 2, 3}));
    ASSERT_TRUE(param["param_sl"] != ParamVal({"test1", "test2", "test3"}));

    ASSERT_TRUE(param["param_bool"] == false);
    ASSERT_TRUE(param["param_flt"] == 3.4f);
    ASSERT_TRUE(param["param_int"] == 5);
    ASSERT_TRUE(param["param_str"] == "tset");
    ASSERT_TRUE(param["param_bl"] == ParamVal({false, true, false, true}));
    ASSERT_TRUE(param["param_fl"] == ParamVal({7.6f, 6.5f}));
    ASSERT_TRUE(param["param_il"] == ParamVal({5, 4, 3, 2, 1}));
    ASSERT_TRUE(
            param["param_sl"] ==
            ParamVal({"1tset", "2tset", "3tset", "4tset", "5tset"}));

#if PARAM_TEST_LOG
    Param copy_param = param;
    for (const auto& name : names) {
        std::cout << copy_param[name].str() << std::endl;
    }
#endif

    // Serialize, then restore into a Param that still holds the defaults.
    // Each entry must come back with the overwritten value; without these
    // checks a no-op or lossy from_bytes() would go unnoticed.
    Param loaded_param(info);
    std::string bytes = param.to_bytes();
    loaded_param.from_bytes(bytes);

    ASSERT_TRUE(loaded_param["param_bool"] == param["param_bool"]);
    ASSERT_TRUE(loaded_param["param_flt"] == param["param_flt"]);
    ASSERT_TRUE(loaded_param["param_int"] == param["param_int"]);
    ASSERT_TRUE(loaded_param["param_str"] == param["param_str"]);
    ASSERT_TRUE(loaded_param["param_bl"] == param["param_bl"]);
    ASSERT_TRUE(loaded_param["param_fl"] == param["param_fl"]);
    ASSERT_TRUE(loaded_param["param_il"] == param["param_il"]);
    ASSERT_TRUE(loaded_param["param_sl"] == param["param_sl"]);

#if PARAM_TEST_LOG
    for (auto& kv : loaded_param.raw()) {
        std::cout << kv.first << ":\n" << kv.second.str() << std::endl;
    }
#endif
}

}  // namespace custom

#endif