#include "test/common/elemwise.h"
#include "src/common/utils.cuh"
#include "test/common/checker.h"
#include "test/common/utils.h"
#include "megdnn/oprs/general.h"
#include "test/common/fix_gtest_on_platforms_without_exception.inl"
using namespace megdnn;
using namespace test;
namespace {
void fma3_extra_opr_impl(const TensorNDArray& data) {
megdnn_assert(data.size() == 4);
auto handle = create_cpu_handle(2);
auto opr = handle->create_operator<Elemwise>();
using Mode = Elemwise::Mode;
opr->param().mode = Mode::MUL;
opr->exec({data[0], data[1]}, data[3]);
opr->param().mode = Mode::ADD;
opr->exec({data[2], data[3]}, data[3]);
}
void fma4_extra_opr_impl(const TensorNDArray& data) {
megdnn_assert(data.size() == 5);
std::vector<uint8_t> tmp_storage(data[4].layout.span().dist_byte());
TensorND tmp;
tmp.reset_ptr(tmp_storage.data());
tmp.layout = data[4].layout;
tmp.layout.init_contiguous_stride();
auto handle = create_cpu_handle(2);
auto opr = handle->create_operator<Elemwise>();
using Mode = Elemwise::Mode;
opr->param().mode = Mode::MUL;
opr->exec({data[0], data[1]}, data[4]);
opr->exec({data[2], data[3]}, tmp);
opr->param().mode = Mode::ADD;
opr->exec({tmp, data[4]}, data[4]);
}
/*!
 * \brief Build a Float32 layout with the given shape and explicit strides.
 *
 * \param shp shape of the layout
 * \param stride one stride per axis; its length must equal shp.ndim
 */
TensorLayout make_layout(
        const TensorShape& shp, std::initializer_list<ptrdiff_t> stride) {
    TensorLayout layout{shp, dtype::Float32()};
    megdnn_assert(stride.size() == shp.ndim);
    // Overwrite the strides chosen by the constructor, axis by axis.
    size_t axis = 0;
    for (auto it = stride.begin(); it != stride.end(); ++it, ++axis) {
        layout.stride[axis] = *it;
    }
    return layout;
}
}
namespace megdnn {
namespace test {
namespace elemwise {
//! Opens an explicit specialization of run_test<name>; the braces following
//! the macro invocation form the test body, which receives the handle under
//! test as `handle`.
#define DEF_TEST(name) \
template <> \
void run_test<name>(Handle * handle)
//! SIN on a {3, 4, 1} input, first with Float32 and then with Float16.
DEF_TEST(unary) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);
    checker.set_param(Mode::SIN);

    auto run_with = [&checker](DType input_dtype) {
        checker.set_dtype(0, input_dtype).execs({{3, 4, 1}, {}});
    };
    run_with(dtype::Float32());
    run_with(dtype::Float16());
}
//! ADD with a variety of broadcast patterns between the two inputs; currently
//! run for Float32 only.
DEF_TEST(binary_brdcst) {
    using Mode = ElemwiseForward::Param::Mode;

    auto run = [&](DType dtype) {
        Checker<ElemwiseForward> checker(handle);
        checker.set_param(Mode::ADD);
        checker.set_dtype(0, dtype);
        checker.set_dtype(1, dtype);

        // Cases that spell out the output shape, executed in table order.
        const TensorShapeArray explicit_cases[] = {
                {{3, 1}, {1, 3}, {3, 3}},
                {{10, 11}, {10, 11}, {10, 11}},
                {{2, 3, 4, 5, 6, 7}, {1, 3, 1, 1, 6, 1}, {2, 3, 4, 5, 6, 7}},
                {{1, 3, 1, 1, 6, 1}, {2, 3, 4, 5, 6, 7}, {2, 3, 4, 5, 6, 7}},
                {{256, 256, 3}, {1, 1, 3}, {256, 256, 3}},
                {{1, 1, 3}, {256, 256, 3}, {256, 256, 3}},
                {{8, 1, 6, 1}, {1, 7, 1, 5}, {8, 7, 6, 5}},
                {{1, 7, 1, 5}, {8, 1, 6, 1}, {8, 7, 6, 5}},
                {{5, 4}, {1, 1}, {5, 4}},
                {{1, 1}, {5, 4}, {5, 4}},
                {{5, 4}, {1, 4}, {5, 4}},
                {{1, 4}, {5, 4}, {5, 4}},
                {{15, 3, 5}, {15, 1, 5}, {15, 3, 5}},
                {{15, 1, 5}, {15, 3, 5}, {15, 3, 5}},
                {{15, 3, 5}, {1, 3, 5}, {15, 3, 5}},
                {{1, 3, 5}, {15, 3, 5}, {15, 3, 5}},
                {{15, 3, 5}, {1, 3, 1}, {15, 3, 5}},
                {{1, 3, 1}, {15, 3, 5}, {15, 3, 5}},
                {{3, 1}, {1, 4}, {3, 4}},
                {{2, 3, 1, 5}, {4, 5}, {2, 3, 4, 5}},
                {{3, 1, 1}, {4, 5}, {3, 4, 5}},
        };
        for (auto&& shapes : explicit_cases) {
            checker.execs(shapes);
        }

        // 1-D inputs: same length, and a single element on either side.
        const auto len = 1000u;
        checker.execs({{len}, {len}, {len}});
        checker.execs({{1}, {len}, {len}});
        checker.execs({{len}, {1}, {len}});

        // 2-D inputs: each of the four axes (two per input) is independently
        // either kept or collapsed to 1, giving 16 combinations. Bits 0/1 of
        // the mask select the axes of input 0, bits 2/3 those of input 1.
        const auto rows = 200u, cols = 100u;
        for (auto mask = 0u; mask < 16; ++mask) {
            const size_t lhs_rows = (mask & 1) ? 1 : rows;
            const size_t lhs_cols = (mask & 2) ? 1 : cols;
            const size_t rhs_rows = (mask & 4) ? 1 : rows;
            const size_t rhs_cols = (mask & 8) ? 1 : cols;
            checker.execs({{lhs_rows, lhs_cols}, {rhs_rows, rhs_cols}, {}});
        }

        // 5-D inputs where the second operand has several axes of extent 1.
        checker.execs({{2, 3, 4, 5, 6}, {1, 3, 1, 5, 6}, {2, 3, 4, 5, 6}});
        checker.execs({{2, 3, 4, 5, 6}, {2, 1, 4, 1, 6}, {2, 3, 4, 5, 6}});
    };

    run(dtype::Float32());
}
//! ADD on two {2, 3} Float32 inputs whose first-axis stride is 4 instead of
//! the constructor-chosen one; the output uses the default {2, 3} layout.
DEF_TEST(binary_non_contig) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);
    checker.set_param(Mode::ADD);

    const TensorLayout dst_layout{{2, 3}, dtype::Float32()};
    TensorLayout src_layout = dst_layout;
    src_layout.stride[0] = 4;

    checker.execl({src_layout, src_layout, dst_layout});
}
//! COND_LEQ_MOV with broadcasting inputs: default dtypes, Float32 and Float16,
//! followed by inputs of differing rank and two shape sets that must throw.
DEF_TEST(ternary) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);
    checker.set_param(Mode::COND_LEQ_MOV);

    const TensorShapeArray bcast_shapes = {
            {1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}};
    // Assign `dt` to tensor slots 0 .. count-1, in increasing index order.
    auto set_dtypes = [&checker](size_t count, DType dt) {
        for (size_t idx = 0; idx < count; ++idx) {
            checker.set_dtype(idx, dt);
        }
    };

    // Without any explicit dtype.
    checker.execs(bcast_shapes);

    // Float32 on the three inputs.
    set_dtypes(3, dtype::Float32());
    checker.execs(bcast_shapes);

    // Float16 on the three inputs and on the output.
    set_dtypes(4, dtype::Float16());
    checker.execs(bcast_shapes);

    // Inputs of differing rank.
    checker.execs({{2, 1, 1, 5}, {4, 5}, {3, 1, 1}, {2, 3, 4, 5}});
    checker.execs({{3, 1, 1}, {5}, {4, 1}, {3, 4, 5}});

    // These shape sets are expected to be rejected.
    ASSERT_THROW(checker.execs({{2, 3, 4}, {4, 1}, {1}, {2, 3, 4}}), MegDNNError);
    ASSERT_THROW(checker.execs({{2, 4, 4}, {4, 1}, {3, 1, 1}, {2, 3, 4}}), MegDNNError);
}
//! COND_LEQ_MOV on three {2, 3} Float32 inputs whose first-axis stride is set
//! to 4; the output uses the default {2, 3} layout.
DEF_TEST(ternary_non_contig) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);
    checker.set_param(Mode::COND_LEQ_MOV);

    const TensorLayout dst_layout{{2, 3}, dtype::Float32()};
    TensorLayout src_layout = dst_layout;
    src_layout.stride[0] = 4;

    checker.execl({src_layout, src_layout, src_layout, dst_layout});
}
//! FUSE_MUL_ADD3 with broadcasting inputs, checked through
//! fma3_extra_opr_impl registered as the checker's extra operator impl.
DEF_TEST(fuse_mul_add3) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);
    checker.set_param(Mode::FUSE_MUL_ADD3).set_extra_opr_impl(fma3_extra_opr_impl);

    // Returns {s0, s1, s2, dst}. dst has the rank of s0 and, per axis, the
    // larger extent of s0 and s1 (s1 counts as 1 on axes it does not have).
    // s2 does not take part in computing dst.
    auto make_shape = [](const TensorShape& s0, const TensorShape& s1,
                         const TensorShape& s2) {
        TensorShape dst;
        dst.ndim = s0.ndim;
        for (size_t axis = 0; axis < dst.ndim; ++axis) {
            const size_t lhs = s0[axis];
            const size_t rhs = axis < s1.ndim ? s1[axis] : 1;
            dst[axis] = std::max(lhs, rhs);
        }
        return TensorShapeArray{s0, s1, s2, dst};
    };
    auto run_case = [&](const TensorShape& s0, const TensorShape& s1,
                        const TensorShape& s2) {
        checker.exec(make_shape(s0, s1, s2));
    };

    run_case({2, 1}, {2, 2}, {2, 2});
    run_case({2, 2}, {2, 1}, {2, 2});
    run_case({2, 2}, {2, 2}, {1});
    run_case({3, 1}, {1, 3}, {3, 1});
    run_case({2, 1, 2, 1, 2, 1}, {1, 2, 1, 2, 1, 2}, {1});
    run_case({1, 1, 3}, {5, 8, 1}, {5, 8, 1});
    run_case({1, 192, 9, 16}, {1}, {1, 192, 9, 16});
}
//! FUSE_MUL_ADD3 on three {2, 3} Float32 inputs whose first-axis stride is set
//! to 4, checked through fma3_extra_opr_impl.
DEF_TEST(fuse_mul_add3_non_contig) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);
    checker.set_param(Mode::FUSE_MUL_ADD3).set_extra_opr_impl(fma3_extra_opr_impl);

    const TensorLayout dst_layout{{2, 3}, dtype::Float32()};
    TensorLayout src_layout = dst_layout;
    src_layout.stride[0] = 4;

    checker.execl({src_layout, src_layout, src_layout, dst_layout});
}
//! FUSE_MUL_ADD4 with broadcasting inputs, checked through
//! fma4_extra_opr_impl registered as the checker's extra operator impl.
DEF_TEST(fuse_mul_add4) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);
    checker.set_param(Mode::FUSE_MUL_ADD4).set_extra_opr_impl(fma4_extra_opr_impl);

    // Returns {s0, s1, s0, s1, dst}, or {s0, s1, s1, s0, dst} when `swap` is
    // set. dst has the rank of s0 and, per axis, the larger extent of s0 and
    // s1 (s1 counts as 1 on axes it does not have).
    auto make_shape = [](const TensorShape& s0, const TensorShape& s1,
                         bool swap = false) {
        TensorShape dst;
        dst.ndim = s0.ndim;
        for (size_t axis = 0; axis < dst.ndim; ++axis) {
            const size_t lhs = s0[axis];
            const size_t rhs = axis < s1.ndim ? s1[axis] : 1;
            dst[axis] = std::max(lhs, rhs);
        }
        TensorShapeArray all_shapes{s0, s1, s0, s1, dst};
        if (swap) {
            std::swap(all_shapes[2], all_shapes[3]);
        }
        return all_shapes;
    };

    checker.exec(make_shape({2, 2}, {2, 2}));
    checker.exec(make_shape({3, 1}, {1, 3}));
    checker.exec(make_shape({2, 1, 2, 1, 2, 1}, {1, 2, 1, 2, 1, 2}));
    checker.exec(make_shape({4, 2}, {1, 2}, true));
}
//! RMULH for Int8, Int16 and Int32. Input 0 is drawn from the full range of
//! the dtype, input 1 from the upper half of its positive range.
DEF_TEST(rmulh) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);

    auto run_for_dtype = [&checker](auto dtype) {
        using Trait = DTypeTrait<decltype(dtype)>;
        auto lowest = Trait::min();
        auto highest = Trait::max();
        UniformIntRNG full_range_rng{lowest, highest};
        UniformIntRNG upper_half_rng{(highest >> 1) + 1, highest};

        checker.set_param({Mode::RMULH});
        checker.set_dtype(0, dtype);
        checker.set_dtype(1, dtype);
        checker.set_dtype(2, dtype);
        checker.set_rng(0, &full_range_rng);
        checker.set_rng(1, &upper_half_rng);

        checker.execs({{7, 9, 11, 13}, {1}, {}});
        checker.execs({{16, 3, 256, 256}, {1}, {}});
        checker.execs({{2, 3, 1, 7}, {2, 3, 1, 7}, {}});
        checker.execs({{9, 5, 4}, {1, 5, 1}, {}});
        checker.execs({{233}, {1}, {}});
    };

    run_for_dtype(dtype::Int8());
    run_for_dtype(dtype::Int16());
    run_for_dtype(dtype::Int32());
}
//! Unary modes on {1, 127} and {1, 7} inputs: RELU and ABS for Int8, Int16 and
//! Int32, then a list of float modes on Float32 with inputs drawn from
//! UniformFloatRNG(1e-2, 6e1).
DEF_TEST(unary1) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);

    // Run one mode on both input shapes.
    auto run_mode = [&checker](Mode mode) {
        checker.set_param(mode).execs({{1, 127}, {}});
        checker.set_param(mode).execs({{1, 7}, {}});
    };
    auto run_int_modes = [&run_mode]() {
        run_mode(Mode::RELU);
        run_mode(Mode::ABS);
    };

    checker.set_dtype(0, dtype::Int8());
    run_int_modes();
    checker.set_dtype(0, dtype::Int16());
    run_int_modes();
    checker.set_dtype(0, dtype::Int32());
    run_int_modes();

    UniformFloatRNG rng(1e-2, 6e1);
    checker.set_rng(0, &rng);
    checker.set_epsilon(1e-5);
    checker.set_dtype(0, dtype::Float32());
    for (Mode mode :
         {Mode::ABS, Mode::LOG, Mode::COS, Mode::SIN, Mode::FLOOR, Mode::CEIL,
          Mode::SIGMOID, Mode::EXP, Mode::TANH, Mode::FAST_TANH, Mode::RELU,
          Mode::ROUND}) {
        run_mode(mode);
    }
}
//! Binary modes over two fixed shape tables: a "complete" table for the
//! arithmetic and fused modes on Float32, then the "basic" table (MIN, MAX)
//! plus the "complete" table (ADD, MUL, MAX, MIN, SUB) for Int8, Int16, Int32.
DEF_TEST(binary1) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);

    const TensorShapeArray basic_cases[] = {
            {{3, 4, 7}, {3, 4, 7}, {}},
            {{3, 4, 5, 7}, {1, 1, 1, 1}, {}},
            {{1, 1, 1, 1}, {3, 4, 5, 7}, {}},
            {{1, 2, 2}, {1, 1, 1}, {}},
            {{1, 1, 1}, {1, 2, 2}, {}},
            {{1, 7}, {1, 7}, {}},
    };
    const TensorShapeArray complete_cases[] = {
            {{3, 4, 7}, {3, 4, 7}, {}},
            {{3, 4, 5, 7}, {1, 4, 1, 1}, {}},
            {{1, 4, 1, 1}, {3, 4, 5, 7}, {}},
            {{3, 4, 7}, {1, 4, 1}, {}},
            {{1, 4, 1}, {3, 4, 7}, {}},
            {{3, 4, 5, 7}, {1, 1, 1, 1}, {}},
            {{1, 1, 1, 1}, {3, 4, 5, 7}, {}},
            {{1, 7}, {1, 7}, {}},
            {{1, 2, 2}, {1, 2, 1}, {}},
            {{1, 2, 1}, {1, 2, 2}, {}},
            {{1, 2, 2}, {1, 1, 1}, {}},
            {{1, 1, 1}, {1, 2, 2}, {}},
            {{3, 4, 1}, {3, 4, 1}, {}},
    };

    auto run_basic = [&](Mode mode) {
        for (auto&& shapes : basic_cases) {
            checker.set_param(mode).execs(shapes);
        }
    };
    auto run_complete = [&](Mode mode) {
        for (auto&& shapes : complete_cases) {
            checker.set_param(mode).execs(shapes);
        }
    };
    auto run_basic_modes = [&]() {
        run_basic(Mode::MIN);
        run_basic(Mode::MAX);
    };
    auto run_complete_modes = [&]() {
        for (Mode mode : {Mode::ADD, Mode::MUL, Mode::MAX, Mode::MIN, Mode::SUB}) {
            run_complete(mode);
        }
    };

    // Float32: only input 0 gets the custom RNG.
    UniformFloatRNG rng(1e-5, 7e1);
    checker.set_rng(0, &rng);
    checker.set_epsilon(1e-5);
    checker.set_dtype(0, dtype::Float32());
    checker.set_dtype(1, dtype::Float32());
    run_complete_modes();
    for (Mode mode :
         {Mode::POW, Mode::TRUE_DIV, Mode::FUSE_ADD_SIGMOID, Mode::FUSE_ADD_TANH,
          Mode::FUSE_ADD_RELU, Mode::FUSE_ADD_H_SWISH, Mode::FAST_TANH_GRAD,
          Mode::H_SWISH_GRAD}) {
        run_complete(mode);
    }

    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    run_basic_modes();
    run_complete_modes();

    checker.set_dtype(0, dtype::Int16());
    checker.set_dtype(1, dtype::Int16());
    run_basic_modes();
    run_complete_modes();

    checker.set_dtype(0, dtype::Int32());
    checker.set_dtype(1, dtype::Int32());
    run_basic_modes();
    run_complete_modes();
}
//! FUSE_MUL_ADD3 over a fixed table of input shapes for Int8, Int16, Int32 and
//! Float32 (the latter with a custom RNG on input 0 and epsilon 1e-5).
DEF_TEST(ternary1) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);

    const TensorShapeArray cases[] = {
            {{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}},
            {{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}},
            {{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}},
            {{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}},
            {{1, 7}, {1, 7}, {1, 7}, {}},
            {{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}},
            {{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}},
            {{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}},
    };
    auto run_fma3 = [&]() {
        for (auto&& shapes : cases) {
            checker.set_param(Mode::FUSE_MUL_ADD3).execs(shapes);
        }
    };
    // Set the dtype of the three inputs, in index order.
    auto set_input_dtypes = [&checker](DType dt) {
        checker.set_dtype(0, dt);
        checker.set_dtype(1, dt);
        checker.set_dtype(2, dt);
    };

    set_input_dtypes(dtype::Int8());
    run_fma3();
    set_input_dtypes(dtype::Int16());
    run_fma3();
    set_input_dtypes(dtype::Int32());
    run_fma3();

    UniformFloatRNG rng(1e-5, 7e1);
    checker.set_rng(0, &rng);
    checker.set_epsilon(1e-5);
    set_input_dtypes(dtype::Float32());
    run_fma3();
}
//! Unary modes on {1, 129} and {1, 7} inputs: RELU/ABS/NEGATE for the integer
//! dtypes, the same plus SIGMOID/EXP/TANH/FAST_TANH/H_SWISH for Float32 and
//! Float16, and finally TANH on Float32 inputs drawn from [100, 200].
DEF_TEST(unary2) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);

    // Run one mode on both input shapes.
    auto run_mode = [&checker](Mode mode) {
        checker.set_param(mode).execs({{1, 129}, {}});
        checker.set_param(mode).execs({{1, 7}, {}});
    };
    auto run_int_modes = [&run_mode]() {
        run_mode(Mode::RELU);
        run_mode(Mode::ABS);
        run_mode(Mode::NEGATE);
    };
    auto run_float_modes = [&run_mode, &run_int_modes]() {
        run_int_modes();
        run_mode(Mode::SIGMOID);
        run_mode(Mode::EXP);
        run_mode(Mode::TANH);
        run_mode(Mode::FAST_TANH);
        run_mode(Mode::H_SWISH);
    };

    checker.set_dtype(0, dtype::Int8());
    run_int_modes();
    checker.set_dtype(0, dtype::Int16());
    run_int_modes();
    checker.set_dtype(0, dtype::Int32());
    run_int_modes();

    // Each float section owns its RNG; the checker only holds a pointer to
    // it, so every section installs its own before executing anything.
    {
        UniformFloatRNG f32_rng(1e-5, 7e1);
        checker.set_rng(0, &f32_rng);
        checker.set_epsilon(1e-5);
        checker.set_dtype(0, dtype::Float32());
        run_float_modes();
    }
    {
        UniformFloatRNG f16_rng(1e-2, 1e1);
        checker.set_rng(0, &f16_rng);
        checker.set_epsilon(6e-3);
        checker.set_dtype(0, dtype::Float16());
        run_float_modes();
    }
    {
        UniformFloatRNG large_rng(100, 200);
        checker.set_rng(0, &large_rng);
        checker.set_epsilon(1e-5);
        checker.set_dtype(0, dtype::Float32());
        checker.set_param(Mode::TANH).execs({{1, 1025}, {}});
        checker.set_param(Mode::TANH).execs({{1, 7}, {}});
    }
}
//! Binary modes over a "basic" and a "complete" shape table for Float32,
//! Int8, Int16, Int32 and Float16, plus a few extra single-shape runs of
//! FUSE_ADD_SIGMOID, FUSE_ADD_TANH and TRUE_DIV.
DEF_TEST(binary2) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);

    const TensorShapeArray basic_cases[] = {
            {{3, 4, 7}, {3, 4, 7}, {}},
            {{3, 4, 5, 7}, {1, 1, 1, 1}, {}},
            {{1, 1, 1, 1}, {3, 4, 5, 7}, {}},
            {{1, 2, 2}, {1, 1, 1}, {}},
            {{1, 1, 1}, {1, 2, 2}, {}},
            {{1, 7}, {1, 7}, {}},
    };
    const TensorShapeArray complete_cases[] = {
            {{3, 4, 7}, {3, 4, 7}, {}},
            {{3, 4, 5, 7}, {1, 4, 1, 1}, {}},
            {{1, 4, 1, 1}, {3, 4, 5, 7}, {}},
            {{3, 4, 7}, {1, 4, 1}, {}},
            {{1, 4, 1}, {3, 4, 7}, {}},
            {{3, 4, 5, 7}, {1, 1, 1, 1}, {}},
            {{1, 1, 1, 1}, {3, 4, 5, 7}, {}},
            {{1, 7}, {1, 7}, {}},
            {{1, 2, 2}, {1, 2, 1}, {}},
            {{1, 2, 1}, {1, 2, 2}, {}},
            {{1, 2, 2}, {1, 1, 1}, {}},
            {{1, 1, 1}, {1, 2, 2}, {}},
            {{3, 4, 1}, {3, 4, 1}, {}},
    };

    auto run_basic = [&](Mode mode) {
        for (auto&& shapes : basic_cases) {
            checker.set_param(mode).execs(shapes);
        }
    };
    auto run_complete = [&](Mode mode) {
        for (auto&& shapes : complete_cases) {
            checker.set_param(mode).execs(shapes);
        }
    };
    auto run_basic_modes = [&]() {
        run_basic(Mode::MIN);
        run_basic(Mode::MAX);
    };
    auto run_complete_modes = [&]() {
        for (Mode mode :
             {Mode::ADD, Mode::MUL, Mode::MAX, Mode::MIN, Mode::SUB,
              Mode::FUSE_ADD_RELU}) {
            run_complete(mode);
        }
    };
    // Set the dtype of both inputs, in index order.
    auto set_input_dtypes = [&checker](DType dt) {
        checker.set_dtype(0, dt);
        checker.set_dtype(1, dt);
    };

    // Float32, first pass: only input 0 gets the custom RNG.
    UniformFloatRNG rng(1e-5, 7e1);
    checker.set_rng(0, &rng);
    checker.set_epsilon(1e-5);
    set_input_dtypes(dtype::Float32());
    run_complete_modes();
    run_complete(Mode::FUSE_ADD_SIGMOID);
    run_complete(Mode::FUSE_ADD_TANH);

    set_input_dtypes(dtype::Int8());
    run_complete_modes();

    set_input_dtypes(dtype::Int16());
    run_complete_modes();

    set_input_dtypes(dtype::Int32());
    run_basic_modes();
    run_complete_modes();

    // Float32, second pass.
    checker.set_rng(0, &rng);
    checker.set_epsilon(1e-5);
    set_input_dtypes(dtype::Float32());
    checker.set_param(Mode::FUSE_ADD_SIGMOID).execs({{3, 4, 7}, {1}, {}});
    checker.set_param(Mode::FUSE_ADD_TANH).execs({{3, 4, 7}, {1}, {}});
    checker.set_param(Mode::TRUE_DIV).execs({{1}, {4}, {}});
    run_basic_modes();
    run_complete_modes();
    run_complete(Mode::TRUE_DIV);

    // Float16: narrower RNG on both inputs and a looser epsilon.
    {
        UniformFloatRNG half_rng(1e-3, 3e1);
        checker.set_rng(0, &half_rng);
        checker.set_rng(1, &half_rng);
        checker.set_epsilon(1e-3);
        set_input_dtypes(dtype::Float16());
        checker.set_param(Mode::FUSE_ADD_SIGMOID).execs({{3, 4, 7}, {1}, {}});
        checker.set_param(Mode::FUSE_ADD_TANH).execs({{3, 4, 7}, {1}, {}});
        run_basic_modes();
        run_complete_modes();
        run_complete(Mode::TRUE_DIV);
        checker.set_param(Mode::TRUE_DIV).execs({{1}, {4}, {}});
    }
}
//! FUSE_MUL_ADD3 over a fixed table of input shapes for Int8, Int16, Int32,
//! Float32 and Float16.
DEF_TEST(ternary2) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);

    const TensorShapeArray cases[] = {
            {{1, 123, 1}, {300, 123, 253}, {300, 123, 253}, {}},
            {{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}},
            {{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}},
            {{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}},
            {{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}},
            {{1, 7}, {1, 7}, {1, 7}, {}},
            {{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}},
            {{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}},
            {{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}},
            {{3, 4, 1}, {1, 1, 1}, {3, 4, 1}, {}},
    };
    auto run_fma3 = [&]() {
        for (auto&& shapes : cases) {
            checker.set_param(Mode::FUSE_MUL_ADD3).execs(shapes);
        }
    };
    // Set the dtype of the three inputs, in index order.
    auto set_input_dtypes = [&checker](DType dt) {
        checker.set_dtype(0, dt);
        checker.set_dtype(1, dt);
        checker.set_dtype(2, dt);
    };

    set_input_dtypes(dtype::Int8());
    run_fma3();
    set_input_dtypes(dtype::Int16());
    run_fma3();
    set_input_dtypes(dtype::Int32());
    run_fma3();

    // Float32: only input 0 gets the custom RNG.
    UniformFloatRNG rng(1e-5, 7e1);
    checker.set_rng(0, &rng);
    checker.set_epsilon(1e-5);
    set_input_dtypes(dtype::Float32());
    run_fma3();

    // Float16: narrower RNG on all three inputs and a looser epsilon.
    {
        UniformFloatRNG half_rng(1e-3, 3e1);
        checker.set_rng(0, &half_rng);
        checker.set_rng(1, &half_rng);
        checker.set_rng(2, &half_rng);
        checker.set_epsilon(1e-3);
        set_input_dtypes(dtype::Float16());
        run_fma3();
    }
}
//! SIN on Float32 inputs with explicitly given strides; the output layout is
//! built from the shape alone.
DEF_TEST(unary3) {
    Checker<Elemwise> checker(handle);
    checker.set_param({Elemwise::Mode::SIN});

    auto run_strided = [&checker](
                               const TensorShape& shp,
                               std::initializer_list<ptrdiff_t> stride) {
        TensorLayoutArray layouts{
                make_layout(shp, stride), TensorLayout{shp, dtype::Float32()}};
        checker.exec(layouts);
    };

    run_strided({2, 2}, {2, 1});
    run_strided({4}, {3});
}
//! ADD on Float32 inputs with explicitly given strides. Every case is run
//! twice, the second time with the two operands swapped.
DEF_TEST(binary3) {
    Checker<Elemwise> checker(handle);
    checker.set_param({Elemwise::Mode::ADD});

    auto run = [&](const TensorShape& shp0, std::initializer_list<ptrdiff_t> stride0,
                   const TensorShape& shp1, std::initializer_list<ptrdiff_t> stride1) {
        // The output shape is deduced from the two input shapes.
        TensorShape out_shape;
        Elemwise::deduce_shape({shp0, shp1}, out_shape);

        const TensorLayout lhs = make_layout(shp0, stride0);
        const TensorLayout rhs = make_layout(shp1, stride1);
        const TensorLayout out{out_shape, dtype::Float32()};

        checker.execl({lhs, rhs, out});
        checker.execl({rhs, lhs, out});
    };

    run({2, 2}, {2, 1}, {2, 2}, {2, 1});
    run({1}, {1}, {3, 3}, {1, 2});
    run({3, 4, 5}, {40, 10, 2}, {1, 4, 1}, {1, 1, 1});
}
//! Run every elemwise mode once per dtype category its ModeTrait allows
//! (Int8/Int32 and/or Float16/Float32) on {3, 9, 7} inputs. Binary modes are
//! additionally run with the second input broadcast along axes 0 and 2, and
//! then along all axes. The test stops at the first failing configuration and
//! prints the mode, dtype and input shapes.
DEF_TEST(all_modes) {
    Checker<Elemwise> checker(handle);
    TensorShapeArray shapes;
    // NOTE(review): small_pos_rng_f32 spans [.1f, .10f], i.e. a single value,
    // so POW always sees base 0.1. The upper bound looks like a typo; confirm
    // the intended range before changing it.
    UniformFloatRNG default_rng_f32{-100.f, 100.f}, pos_rng_f32{.1f, 1000.f},
            small_pos_rng_f32{.1f, .10f}, small_rng_f32{-3.f, 3.f},
            abslt1_rng_f32{-1.f, 1.f}, uniform_0_2_rng{0.f, 2.f},
            tanh_rng_f32{-5.f, 5.f};
    UniformFloatNonZeroRNG nonzero_rng_f32{.1f, 1000.f},
            big_nonzero_rng_f32{100.f, 1000.f};
    UniformIntRNG default_rng_i32{-100, 100}, small_rng_i32{-2, 2},
            shift_rng_i32_i32{0, 31}, shift_rng_i32_i8{0, 7};
    UniformIntNonZeroRNG nonzero_rng_i32{1, 100};

    using Mode = Elemwise::Mode;

    // Hook for skipping modes; currently nothing is skipped.
    auto should_ignore = [](Mode mode) {
        MEGDNN_MARK_USED_VAR(mode);
        return false;
    };

    for (int mode_nr = 0; mode_nr < static_cast<int>(Elemwise::Param::MODE_NR_MEMBER);
         ++mode_nr) {
        auto mode = static_cast<Mode>(mode_nr);
        if (should_ignore(mode)) {
            continue;
        }
        checker.set_param({mode});
        auto&& trait = Elemwise::ModeTrait::from_mode(mode);

        // One shape per input plus a trailing slot for the output.
        shapes.resize(trait.arity + 1);
        for (size_t i = 0; i < shapes.size() - 1; ++i) {
            shapes[i] = {3, 9, 7};
        }

        // Runs the current mode with the given dtype; returns false as soon
        // as the test has recorded a failure.
        auto do_run = [&](DType dtype, float eps = 1e-3) {
            // Work on a private copy of the shapes: the broadcast variants
            // below overwrite the second input's shape, and that must not
            // leak into the runs for the remaining dtypes of this mode
            // (otherwise they would only ever see the fully-broadcast case).
            TensorShapeArray cur_shapes = shapes;

            // Restrict the input value ranges for modes that need it.
            if (mode == Mode::LOG || mode == Mode::LOG1P) {
                checker.set_rng(0, &pos_rng_f32);
            } else if (mode == Mode::POW) {
                checker.set_rng(0, &small_pos_rng_f32);
                checker.set_rng(1, &small_rng_f32);
            } else if (mode == Mode::EXP || mode == Mode::EXPM1) {
                checker.set_rng(0, &small_rng_f32);
            } else if (mode == Mode::FAST_TANH) {
                checker.set_rng(0, &tanh_rng_f32);
            } else if (mode == Mode::LOG_SUM_EXP) {
                checker.set_rng(0, &big_nonzero_rng_f32);
                checker.set_rng(1, &big_nonzero_rng_f32);
            } else if (
                    mode == Mode::ASIN || mode == Mode::ACOS ||
                    mode == Mode::SIGMOID_GRAD || mode == Mode::TANH_GRAD ||
                    mode == Mode::ERFINV) {
                checker.set_rng(0, &abslt1_rng_f32);
                checker.set_rng(1, &default_rng_f32);
            } else if (mode == Mode::ERFCINV) {
                checker.set_rng(0, &uniform_0_2_rng);
            } else if (
                    mode == Mode::MOD || mode == Mode::TRUE_DIV ||
                    mode == Mode::FLOOR_DIV) {
                // The divisor must never be zero.
                if (dtype.category() == DTypeCategory::INT) {
                    checker.set_rng(0, &default_rng_i32);
                    checker.set_rng(1, &nonzero_rng_i32);
                } else {
                    checker.set_rng(0, &default_rng_f32);
                    checker.set_rng(1, &nonzero_rng_f32);
                }
            } else if (mode == Mode::EQ) {
                checker.set_rng(0, &small_rng_i32);
                checker.set_rng(1, &small_rng_i32);
            } else if (mode == Mode::SHL || mode == Mode::SHR) {
                // Shift amounts are limited by the dtype's size in bytes.
                checker.set_rng(0, &default_rng_i32);
                if (dtype.size() == 4) {
                    checker.set_rng(1, &shift_rng_i32_i32);
                } else {
                    megdnn_assert(dtype.size() == 1);
                    checker.set_rng(1, &shift_rng_i32_i8);
                }
            } else if (mode == Mode::ATAN2) {
                checker.set_rng(0, &nonzero_rng_f32);
                checker.set_rng(1, &nonzero_rng_f32);
            } else {
                RNG* rng;
                if (dtype.category() == DTypeCategory::INT) {
                    rng = &default_rng_i32;
                } else {
                    rng = &default_rng_f32;
                }
                for (size_t i = 0; i < cur_shapes.size(); ++i) {
                    checker.set_rng(i, rng);
                }
            }

            checker.set_epsilon(eps);
            for (size_t i = 0; i < cur_shapes.size(); ++i) {
                checker.set_dtype(i, dtype);
            }

            EXPECT_NO_THROW(checker.execs(cur_shapes));
            if (!::testing::Test::HasFailure() && cur_shapes.size() == 3) {
                // Binary mode: broadcast the second input along axes 0 and 2.
                cur_shapes[1][0] = 1;
                cur_shapes[1][2] = 1;
                EXPECT_NO_THROW(checker.execs(cur_shapes));

                if (!::testing::Test::HasFailure()) {
                    // ... and then along every axis.
                    cur_shapes[1][1] = 1;
                    EXPECT_NO_THROW(checker.execs(cur_shapes));
                }
            }

            if (::testing::Test::HasFailure()) {
                printf("failed on mode=%d(%s) dtype=%s\n", mode_nr, trait.name,
                       dtype.name());
                for (auto&& i : cur_shapes) {
                    printf("ishape: %s\n", i.to_string().c_str());
                }
                return false;
            }
            return true;
        };

// Leave the whole test as soon as one configuration fails.
#define run(...)                    \
    do {                            \
        if (!do_run(__VA_ARGS__)) { \
            return;                 \
        }                           \
    } while (0)

        if (trait.allow_int) {
            run(dtype::Int8{});
            run(dtype::Int32{});
        }
        if (trait.allow_float) {
            DNN_FLOAT16_SELECT(
                    run(dtype::Float16{}, mode == Mode::FAST_TANH_GRAD ? 0.5 : 0.05), );
            run(dtype::Float32{});
        }
    }
#undef run
}
//! Unary modes on a {3, 4, 1} input with strides {-12, -4, -1}: RELU and ABS
//! for Int8/Int16/Int32, then a list of float modes on Float32.
DEF_TEST(unary_negative_stride) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);

    // Run one mode on the negative-stride input of the given dtype.
    auto run_mode = [&checker](Mode mode, DType dt) {
        checker.set_param(mode).execl({{{3, 4, 1}, {-12, -4, -1}, dt}, {}});
    };
    auto run_int_mode = [&run_mode](Mode mode) {
        run_mode(mode, dtype::Int8());
        run_mode(mode, dtype::Int16());
        run_mode(mode, dtype::Int32());
    };

    run_int_mode(Mode::RELU);
    run_int_mode(Mode::ABS);

    UniformFloatRNG rng(1e-2, 6e1);
    checker.set_rng(0, &rng);
    checker.set_epsilon(1e-5);
    for (Mode mode :
         {Mode::ABS, Mode::LOG, Mode::COS, Mode::SIN, Mode::FLOOR, Mode::CEIL,
          Mode::SIGMOID, Mode::EXP, Mode::RELU, Mode::ROUND, Mode::TANH,
          Mode::FAST_TANH}) {
        run_mode(mode, dtype::Float32());
    }
}
//! Binary modes where input 0 is {3, 4, 7} with strides {-28, -7, -1} and
//! input 1 is a default-layout {1, 4, 1} tensor: arithmetic modes for
//! Int8/Int16/Int32, then float-only modes on Float32.
DEF_TEST(binary_negative_stride) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);

    // Run one mode with both inputs of the given dtype.
    auto run_mode = [&checker](Mode mode, DType dt) {
        checker.set_param(mode).execl(
                {{{3, 4, 7}, {-28, -7, -1}, dt}, {{1, 4, 1}, dt}, {}});
    };

    for (Mode mode : {Mode::ADD, Mode::MUL, Mode::MAX, Mode::MIN, Mode::SUB}) {
        run_mode(mode, dtype::Int8());
        run_mode(mode, dtype::Int16());
        run_mode(mode, dtype::Int32());
    }

    UniformFloatRNG rng(1e-2, 2e1);
    checker.set_rng(0, &rng);
    checker.set_epsilon(1e-5);
    for (Mode mode :
         {Mode::POW, Mode::TRUE_DIV, Mode::FUSE_ADD_SIGMOID, Mode::FUSE_ADD_TANH,
          Mode::FUSE_ADD_RELU, Mode::FUSE_ADD_H_SWISH, Mode::FAST_TANH_GRAD,
          Mode::H_SWISH_GRAD}) {
        run_mode(mode, dtype::Float32());
    }
}
//! FUSE_MUL_ADD3 on three {1, 7} inputs with negative strides, for Int8,
//! Int16, Int32 and finally Float32 (with a custom RNG on input 0).
DEF_TEST(ternary_negative_stride) {
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);
    checker.set_param(Mode::FUSE_MUL_ADD3);

    // Run the three-input case with every input of the given dtype.
    auto run_with = [&checker](DType dt) {
        checker.execl(
                {{{1, 7}, {-7, -1}, dt},
                 {{1, 7}, {-3, -1}, dt},
                 {{1, 7}, {-7, -1}, dt},
                 {}});
    };

    run_with(dtype::Int8());
    run_with(dtype::Int16());
    run_with(dtype::Int32());

    UniformFloatRNG rng(1e-2, 2e1);
    checker.set_rng(0, &rng);
    checker.set_epsilon(1e-5);
    run_with(dtype::Float32());
}
//! Spot-check the ModeTrait entries of a few modes: arity, commutability and
//! the dtype categories they accept.
TEST(TEST_ELEMWISE, MODE_TRAIT) {
    using M = Elemwise::Mode;
    using T = Elemwise::ModeTrait;

    const auto& add = T::from_mode(M::ADD);
    const auto& true_div = T::from_mode(M::TRUE_DIV);
    const auto& rmulh = T::from_mode(M::RMULH);

    // arity
    ASSERT_EQ(1u, T::from_mode(M::RELU).arity);
    ASSERT_EQ(2u, add.arity);
    ASSERT_EQ(3u, T::from_mode(M::FUSE_MUL_ADD3).arity);
    ASSERT_EQ(4u, T::from_mode(M::FUSE_MUL_ADD4).arity);

    // commutability
    ASSERT_TRUE(add.commutable);
    ASSERT_FALSE(true_div.commutable);

    // accepted dtype categories
    ASSERT_TRUE(add.allow_int);
    ASSERT_FALSE(T::from_mode(M::EXP).allow_int);
    ASSERT_TRUE(add.allow_float);
    ASSERT_FALSE(T::from_mode(M::SHL).allow_float);

    ASSERT_TRUE(rmulh.commutable);
    ASSERT_FALSE(rmulh.allow_float);

    ASSERT_TRUE(T::from_mode(M::XOR).allow_bool);
}
} } }