// megenginelite-sys 1.8.2
//
// A safe megenginelite wrapper in Rust
// Documentation
/**
 * \file dnn/test/common/topk.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "test/common/topk.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs/general.h"
#include "test/common/checker.h"

using namespace megdnn;
using namespace test;

namespace {
class EqualValueRng final : public RNG {
    std::mt19937_64 m_rng{23};

public:
    void gen(const TensorND& tensor) override {
        memset(tensor.raw_ptr(), 0, tensor.layout.span().dist_byte());
        ASSERT_EQ(2u, tensor.layout.ndim);
        size_t m = tensor.layout[0], n = tensor.layout[1];
        for (size_t i = 0; i < m; ++i) {
            int pos0 = m_rng() % n, pos1;
            do {
                pos1 = m_rng() % n;
            } while (pos0 == pos1);

            pos0 += i * n;
            pos1 += i * n;

#define CASE(ev, dt)                             \
    case DTypeEnum::ev: {                        \
        auto p = tensor.ptr<dt>();               \
        p[pos0] = p[pos1] = static_cast<dt>(-1); \
        break;                                   \
    }

            switch (tensor.layout.dtype.enumv()) {
                CASE(Float32, float);
                CASE(Int32, int);
                DNN_INC_FLOAT16(CASE(Float16, half_float::half));
                default:
                    megdnn_throw("bad dtype");
            }
        }
#undef CASE
    }
};
}  // namespace

/*!
 * \brief Run the shared TopK correctness suite for element type \p Dtype.
 *
 * Exercises the KTH_ONLY, VALUE_IDX_NOSORT and VALUE_IDX_SORTED modes over
 * positive and negative k, several shapes and non-contiguous (explicit lda)
 * layouts, then runs a tie-breaking pass whose inputs come from EqualValueRng.
 *
 * \tparam Dtype dtype::Float32, dtype::Int32 or dtype::Float16; any other
 *      dtype throws.
 * \param handle handle whose TopK implementation is checked via Checker.
 */
template <typename Dtype>
void test::run_topk_test(Handle* handle) {
    Checker<TopK> checker{handle};
    using Mode = TopK::Param::Mode;

    bool tie_breaking_mode = false;
    // always overwritten by run() before any execution; initialized so the
    // canonizer never reads an indeterminate value
    Mode cur_mode = Mode::KTH_ONLY;

    // Canonizes TopK outputs so that results whose order is not fixed by the
    // operator (NOSORT mode, ties) can be compared element-wise.
    auto output_canonizer = [&](const CheckerHelper::TensorValueArray& arr) {
        if (cur_mode == Mode::KTH_ONLY) {
            return;
        }
        auto pinp = arr[0].ptr<typename DTypeTrait<Dtype>::ctype>();
        auto pval = arr[1].ptr<typename DTypeTrait<Dtype>::ctype>();
        auto pidx = arr.at(2).ptr<int>();
        size_t m = arr[1].layout[0], n = arr[1].layout[1];
        // The input may be non-contiguous (explicit lda), so address it via
        // its strides instead of assuming a row length of layout[1].
        ptrdiff_t inp_row_stride = arr[0].layout.stride[0],
                  inp_col_stride = arr[0].layout.stride[1];
        using idx_val = std::pair<int, typename DTypeTrait<Dtype>::ctype>;
        std::vector<idx_val> data(n);
        // descending by value; equal values are ordered by ascending index so
        // the canonical order is deterministic even when values repeat
        auto compare = [](const idx_val& it1, const idx_val& it2) {
            if (it1.second != it2.second) {
                return it1.second > it2.second;
            }
            return it1.first < it2.first;
        };
        for (size_t i = 0; i < m; ++i) {
            if (cur_mode == Mode::VALUE_IDX_NOSORT) {
                // sort output pairs to canonize
                for (size_t j = 0; j < n; ++j) {
                    data[j].first = pidx[i * n + j];
                    data[j].second = pval[i * n + j];
                }
                std::sort(data.begin(), data.end(), compare);
                for (size_t j = 0; j < n; ++j) {
                    pidx[i * n + j] = data[j].first;
                    pval[i * n + j] = data[j].second;
                }
            }
            if (tie_breaking_mode) {
                // check if indices are correct and mark all indices to be zero
                for (size_t j = 0; j < n; ++j) {
                    auto idx = pidx[i * n + j];
                    auto val = pval[i * n + j];
                    ptrdiff_t inp_off = static_cast<ptrdiff_t>(i) * inp_row_stride +
                                        static_cast<ptrdiff_t>(idx) * inp_col_stride;
                    // + 0 can change the type, such as changing half to float
                    ASSERT_EQ(pinp[inp_off] + 0, val + 0);
                    pidx[i * n + j] = 0;
                }
            }
        }
    };

    // Runs one TopK check of an (m, n) input with the given k and mode;
    // a non-zero lda makes the input non-contiguous with that row stride.
    auto run = [&](int k, size_t m, size_t n, Mode mode, int lda = 0) {
        if (::testing::Test::HasFailure()) {
            return;
        }
        cur_mode = mode;
        checker.set_proxy(k);
        checker.set_param(mode);
        TensorLayout layout{{m, n}, Dtype{}};
        if (lda) {
            layout.stride[0] = lda;
        }

        checker.set_output_canonizer(output_canonizer);

        if (mode == Mode::KTH_ONLY) {
            checker.execl({layout, {}});
        } else {
            checker.execl({layout, {}, {}});
        }
        if (!checker.prev_succ()) {
            fprintf(stderr, "topk failed for (%zu,%zu):%d mode=%d cont=%d tie=%d\n", m,
                    n, k, static_cast<int>(mode), !lda, tie_breaking_mode);
            return;
        }
    };

    // RNG objects must outlive every checker execution that uses them.
    std::unique_ptr<IIDRNG> rng0;
    std::unique_ptr<RNG> rngf16;
    std::unique_ptr<NoReplacementRNG> rng1;
    switch (DTypeTrait<Dtype>::enumv) {
        case DTypeEnum::Float32: {
            rng0 = std::make_unique<UniformFloatRNG>(-100.f, 100.f);
            rng1 = std::make_unique<NoReplacementRNG>(rng0.get());
            checker.set_rng(0, rng1.get());
            break;
        }
        case DTypeEnum::Int32: {
            rng0 = std::make_unique<UniformIntRNG>(INT_MIN, INT_MAX);
            rng1 = std::make_unique<NoReplacementRNG>(rng0.get());
            checker.set_rng(0, rng1.get());
            break;
        }
        case DTypeEnum::Float16: {
            rngf16 = std::make_unique<Float16PeriodicalRNG>();
            checker.set_rng(0, rngf16.get());
            break;
        }
        default: {
            megdnn_throw(
                    ssprintf("only float32,int32 and float16 supported for "
                             "cuda and opencl topk"));
        }
    }

    for (auto mode : {Mode::KTH_ONLY, Mode::VALUE_IDX_NOSORT, Mode::VALUE_IDX_SORTED}) {
        run(1, 1, 1, mode);
        run(-1, 1, 1, mode);
        run(1, 23, 1, mode);
        run(1, 23, 100, mode);
        run(-1, 23, 100, mode);
        run(5, 23, 100, mode);
        run(-7, 23, 100, mode);
        run(23, 3, 50001, mode);
        run(5, 123, 3, mode);         // equiv to sort
        run(-5, 123, 3, mode);        // equiv to rev sort
        run(5, 3, 1231, mode, 2000);  // non contig

//! opencl does not support large batch. fix it in the future.
#if MGB_CUDA
        run(3, 70000, 5, mode, 10);  // non contig
#endif
    }

    // special case to check if tie-break is correct
    auto tie_rng = std::make_unique<EqualValueRng>();
    tie_breaking_mode = true;
    checker.set_rng(0, tie_rng.get());
    for (auto mode : {Mode::VALUE_IDX_NOSORT, Mode::VALUE_IDX_SORTED}) {
        run(3, 1, 5, mode);
        run(3, 25, 4567, mode);
        run(8, 132, 10, mode);
    }
}
namespace megdnn {
namespace test {

// Explicit instantiations of the TopK test driver for every element type the
// backends are tested with; Float16 is only emitted when float16 support is
// compiled in (DNN_INC_FLOAT16).
template void run_topk_test<dtype::Float32>(Handle*);
template void run_topk_test<dtype::Int32>(Handle*);
DNN_INC_FLOAT16(template void run_topk_test<dtype::Float16>(Handle*));

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen