// megenginelite-sys 1.8.2
//
// A safe megenginelite wrapper in Rust
// Documentation
/**
 * \file dnn/src/rocm/argsort/argsort.cpp.hip
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "hcc_detail/hcc_defs_prologue.h"

#include "src/rocm/utils.h.hip"
#include "./argsort.h.hip"
#include "./bitonic_sort.h.hip"
#include "megdnn/basic_types.h"

#include "hipcub/device/device_radix_sort.hpp"
#include "hipcub/device/device_segmented_radix_sort.hpp"

using namespace megdnn;
using namespace rocm;

namespace {
/*!
 * Random-access offset source for segmented radix sort: position idx maps to
 * bias + idx * stride. With rows of length N, (bias=0, stride=N) produces the
 * segment begin offsets and (bias=N, stride=N) the segment end offsets.
 */
struct StridedOffsetIterator {
    int bias, stride;

    StridedOffsetIterator(int b, int s) : bias(b), stride(s) {}

    __device__ __forceinline__ int operator[](int idx) const {
        return bias + idx * stride;
    }
};

//! Bitonic sort is selected whenever a row is short enough for it; in that
//! regime it is always faster than radix sort. The batch size is irrelevant.
bool use_bitonic(uint32_t /*M*/, uint32_t N) {
    return N <= BITONIC_SORT_MAX_LENGTH;
}

//! Decide between one segmented radix sort over all rows and one plain radix
//! sort per row, based only on the number of rows.
bool use_segmented(uint32_t M, uint32_t /*N*/) {
    // Empirical timings:
    //   sort(1, 1e6): 0.574ms    segsort({1,2,8,16}, 1e6): 7-8ms
    //   sort(1, 1e7): 3.425ms    segsort({1,2,8,16}, 1e7): 71-84ms
    // A segmented sort costs roughly 7x-10x a plain sort on small batches, so
    // it only wins once the batch is large enough.
    const uint32_t min_segmented_batch = 8;
    return M >= min_segmented_batch;
}

//! Fill dst[0, n) with idx % mod. One thread per element; threads past n
//! do nothing.
__global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) {
    const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) {
        return;
    }
    dst[idx] = idx % mod;
}

/*!
 * Number of temporary bytes the radix-sort path needs for M rows of N
 * elements, or 0 when the bitonic path will be taken.
 *
 * Every buffer argument is NULL, so cub_sort_pairs only performs a hipcub
 * temp-storage size query.
 * NOTE(review): end_bit is fixed at 32 for every ctype. forward() passes the
 * same value, so the queried size matches the real call.
 */
template <typename ctype>
size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) {
    if (!use_bitonic(M, N)) {
        const int end_bit = static_cast<int>(sizeof(float) * 8);
        return argsort::cub_sort_pairs<ctype, int>(
                is_ascending, NULL, 0, NULL, NULL, NULL, NULL, M, N, 0, end_bit,
                NULL);
    }
    return 0;
}
}  // anonymous namespace

/*!
 * \brief Radix-sort M independent rows of N (key, value) pairs with hipcub.
 *
 * Row r occupies elements [r * N, (r + 1) * N) of every buffer. When
 * use_segmented(M, N) holds, one DeviceSegmentedRadixSort call sorts all rows.
 * Otherwise DeviceRadixSort is invoked once per row.
 *
 * Size query: with workspace == NULL, hipcub writes the required temporary
 * storage size into workspace_size. In the per-row path, the loop stops after
 * the first row when keys_in is NULL, so an all-NULL query issues one call.
 *
 * \return workspace_size as left by the hipcub call(s): the queried size in
 *      query mode, or the caller-provided size otherwise.
 */
template <typename KeyType, typename ValueType>
MEGDNN_NOINLINE size_t argsort::cub_sort_pairs(
        bool is_ascending, void* workspace, size_t workspace_size,
        const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in,
        ValueType* values_out, uint32_t M, uint32_t N, int begin_bit,
        int end_bit, hipStream_t stream) {
    if (use_segmented(M, N)) {
        // Segment r spans offsets [r * N, r * N + N).
        const StridedOffsetIterator seg_begin(0, N);
        const StridedOffsetIterator seg_end(N, N);
        hipError_t status;
        if (is_ascending) {
            status = hipcub::DeviceSegmentedRadixSort::SortPairs(
                    workspace, workspace_size, keys_in, keys_out, values_in,
                    values_out, N * M, M, seg_begin, seg_end, begin_bit,
                    end_bit, stream);
        } else {
            status = hipcub::DeviceSegmentedRadixSort::SortPairsDescending(
                    workspace, workspace_size, keys_in, keys_out, values_in,
                    values_out, N * M, M, seg_begin, seg_end, begin_bit,
                    end_bit, stream);
        }
        hip_check(status);
        return workspace_size;
    }

    for (size_t row = 0; row < M; ++row) {
        const size_t offset = N * row;
        hipError_t status =
                is_ascending
                        ? hipcub::DeviceRadixSort::SortPairs(
                                  workspace, workspace_size, keys_in + offset,
                                  keys_out + offset, values_in + offset,
                                  values_out + offset, N, begin_bit, end_bit,
                                  stream)
                        : hipcub::DeviceRadixSort::SortPairsDescending(
                                  workspace, workspace_size, keys_in + offset,
                                  keys_out + offset, values_in + offset,
                                  values_out + offset, N, begin_bit, end_bit,
                                  stream);
        hip_check(status);
        if (keys_in == NULL) {
            // Size query: every row needs the same scratch, so one call is
            // enough.
            break;
        }
    }
    return workspace_size;
}

/*!
 * \brief Bytes of workspace argsort::forward needs for M rows of N elements.
 *
 * The result is the radix-sort scratch size (0 when bitonic sort is used).
 * When the caller does not supply its own index array (!iptr_src_given), that
 * scratch is rounded up to a multiple of sizeof(float), and room for M * N
 * ints of generated indices is appended.
 *
 * Throws via megdnn_throw when dtype is not in ARGSORT_FOREACH_CTYPE.
 */
size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
                                           bool is_ascending,
                                           bool iptr_src_given) {
    size_t sort_bytes = 0;
    switch (dtype.enumv().ev) {
#define cb(ctype)                                                   \
    case DTypeTrait<ctype>::enumv:                                  \
        sort_bytes = get_sort_workspace<ctype>(M, N, is_ascending); \
        break;
        ARGSORT_FOREACH_CTYPE(cb)
#undef cb
        default:
            megdnn_throw("argsort only supports float, int32 and float16");
    }
    if (iptr_src_given) {
        return sort_bytes;
    }
    const size_t aligned_sort_bytes =
            DIVUP(sort_bytes, sizeof(float)) * sizeof(float);
    return aligned_sort_bytes + M * N * sizeof(int);
}

template <typename dtype>
void argsort::forward(const dtype* sptr, dtype* dptr, int* iptr,
                      void* workspace, uint32_t M, uint32_t N,
                      bool is_ascending, hipStream_t stream,
                      const int* iptr_src) {
    size_t wk_size = get_sort_workspace<dtype>(M, N, is_ascending);
    if (!iptr_src) {
        int* ptr = reinterpret_cast<int*>(static_cast<uint8_t*>(workspace) +
                                          DIVUP(wk_size, sizeof(float)) *
                                                  sizeof(float));
        kern_arange<<<DIVUP(N * M, 512), 512, 0, stream>>>(ptr, M * N, N);
        iptr_src = ptr;
    }

    if (use_bitonic(M, N)) {
        hip_check(bitonic_sort(M, N, sptr, iptr_src, dptr, iptr, is_ascending,
                                stream));
    } else {
        cub_sort_pairs(is_ascending, workspace, wk_size, sptr, dptr, iptr_src,
                       iptr, M, N, 0, sizeof(float)*8, stream);
    }
}

namespace megdnn {
namespace rocm {

// Explicit instantiations.
// - forward<T> is emitted for every dtype in ARGSORT_FOREACH_CTYPE. Its call
//   to get_sort_workspace<T> also instantiates cub_sort_pairs<T, int>.
// - cub_sort_pairs<T, T> is emitted for 32- and 64-bit unsigned keys,
//   presumably for callers in other translation units (TODO confirm).
#define INST_ARGSORT_FORWARD(T)                                             \
    template void argsort::forward<T>(const T*, T*, int*, void*, uint32_t, \
                                      uint32_t, bool, hipStream_t,         \
                                      const int*);

#define INST_ARGSORT_CUB_SORT_PAIRS(T)                                  \
    template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs<T, T>(     \
            bool, void*, size_t, const T*, T*, const T*, T*, uint32_t, \
            uint32_t, int, int, hipStream_t);

ARGSORT_FOREACH_CTYPE(INST_ARGSORT_FORWARD)
INST_ARGSORT_CUB_SORT_PAIRS(uint32_t)
INST_ARGSORT_CUB_SORT_PAIRS(uint64_t)

#undef INST_ARGSORT_CUB_SORT_PAIRS
#undef INST_ARGSORT_FORWARD

}  // namespace rocm
}  // namespace megdnn
// vim: ft=rocm syntax=rocm.doxygen