megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/rocm/handle.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "hcc_detail/hcc_defs_prologue.h"

#include "src/common/handle_impl.h"
#include "src/common/version_symbol.h"

#include "src/rocm/handle.h"
#include "src/rocm/miopen_with_check.h"
#include "src/rocm/utils.h"

#include "src/rocm/adaptive_pooling/opr_impl.h"
#include "src/rocm/add_update/opr_impl.h"
#include "src/rocm/argmxx/opr_impl.h"
#include "src/rocm/argsort/opr_impl.h"
#include "src/rocm/batch_normalization/opr_impl.h"
#include "src/rocm/batched_matrix_mul/opr_impl.h"
#include "src/rocm/checksum/opr_impl.h"
#include "src/rocm/convolution/opr_impl.h"
#include "src/rocm/elemwise/opr_impl.h"
#include "src/rocm/eye/opr_impl.h"
#include "src/rocm/fill/opr_impl.h"
#include "src/rocm/indexing_multi_axis_vec/opr_impl.h"
#include "src/rocm/indexing_one_hot/opr_impl.h"
#include "src/rocm/linspace/opr_impl.h"
#include "src/rocm/matrix_mul/opr_impl.h"
#include "src/rocm/param_pack/opr_impl.h"
#include "src/rocm/pooling/opr_impl.h"
#include "src/rocm/powc/opr_impl.h"
#include "src/rocm/reduce/opr_impl.h"
#include "src/rocm/relayout/opr_impl.h"
#include "src/rocm/rng/opr_impl.h"
#include "src/rocm/sleep/opr_impl.h"
#include "src/rocm/topk/opr_impl.h"
#include "src/rocm/type_cvt/opr_impl.h"

#include <hip/hip_version.h>
#include <miopen/version.h>

#include <cstring>

// Stringizing helpers. The two-level STR -> STR_HELPER indirection makes the
// preprocessor expand macro arguments (e.g. MIOPEN_VERSION_MAJOR) before `#`
// turns them into string literals.
#define STR_HELPER(x) #x
#define STR(x)        STR_HELPER(x)

// MIOpen version as a "MAJOR.MINOR.PATCH" string literal, built from the
// version macros presumably provided by <miopen/version.h> (included above).
#define MIOPEN_VERSION_STR    \
    STR(MIOPEN_VERSION_MAJOR) \
    "." STR(MIOPEN_VERSION_MINOR) "." STR(MIOPEN_VERSION_PATCH)

// Emit a compile-time diagnostic reporting the MIOpen version this file is
// built against.
#pragma message "compile with MIOpen " MIOPEN_VERSION_STR " "

// Only the helpers are undefined. MIOPEN_VERSION_STR stays defined, but it
// expands to STR(...), so it cannot be used after this point.
#undef STR
#undef STR_HELPER

namespace megdnn {

//! Create the ROCm backend implementation of Handle; \p computing_handle is
//! forwarded unchanged to the rocm::HandleImpl constructor.
std::unique_ptr<Handle> Handle::make_rocm_handle(
        megcoreComputingHandle_t computing_handle) {
    std::unique_ptr<Handle> handle =
            std::make_unique<rocm::HandleImpl>(computing_handle);
    return handle;
}

//! Create an operator of type Opr by delegating to
//! rocm::HandleImpl::create_operator. The static_cast assumes this object is
//! a rocm::HandleImpl (presumably only reached for ROCm handles — confirm at
//! call sites).
template <typename Opr>
std::unique_ptr<Opr> Handle::create_rocm_operator() {
    auto* const rocm_impl = static_cast<rocm::HandleImpl*>(this);
    return rocm_impl->create_operator<Opr>();
}

//! Explicitly instantiate create_rocm_operator for every operator class
//! enumerated by MEGDNN_FOREACH_OPR_CLASS.
#define MEGDNN_INST_CREATE_ROCM_OPR(opr) \
    template std::unique_ptr<opr> Handle::create_rocm_operator();
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_ROCM_OPR)
#undef MEGDNN_INST_CREATE_ROCM_OPR

}  // namespace megdnn

namespace megdnn {
namespace rocm {

HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
        : HandleImplHelper(comp_handle, HandleType::ROCM) {
    // Get megcore device handle
    megcoreDeviceHandle_t dev_handle;
    megcoreGetDeviceHandle(comp_handle, &dev_handle);
    int dev_id;
    megcoreGetDeviceID(dev_handle, &dev_id);
    if (dev_id < 0) {
        hip_check(hipGetDevice(&dev_id));
    }
    m_device_id = dev_id;
    hip_check(hipGetDeviceProperties(&m_device_prop, dev_id));
    // Get stream from MegCore computing handle.
    //! no version check
    megcore::getROCMContext(comp_handle, &m_megcore_context);
    rocblas_check(rocblas_create_handle(&m_rocblas_handle));
    //! must call miopenCreateWithStream() to create miopen handle, then the
    //! rocblas_handle of miopen will set to be the same stream , otherwise
    //! miopen create rocblas_handle with default stream
    miopen_check(miopenCreateWithStream(&m_miopen_handle, stream()));

    // Set stream for miopen and rocblas handles.
    rocblas_check(rocblas_set_stream(m_rocblas_handle, stream()));

    // Note that all rocblas scalars (alpha, beta) and scalar results such as
    // dot output resides at device side.
    rocblas_check(
            rocblas_set_pointer_mode(m_rocblas_handle, rocblas_pointer_mode_device));

    // init const scalars
    hip_check(hipMalloc(&m_const_scalars, sizeof(ConstScalars)));
    ConstScalars const_scalars_val;
    const_scalars_val.init();
    hip_check(hipMemcpyAsync(
            m_const_scalars, &const_scalars_val, sizeof(ConstScalars),
            hipMemcpyHostToDevice, stream()));
    hip_check(hipStreamSynchronize(stream()));
}

HandleImpl::~HandleImpl() noexcept {
    miopen_check(miopenDestroy(m_miopen_handle));
    rocblas_check(rocblas_destroy_handle(m_rocblas_handle));
    hip_check(hipFree(m_const_scalars));
}

void HandleImpl::ConstScalars::init() {
#if !MEGDNN_DISABLE_FLOAT16
    f16[0].megdnn_x = 0;
    f16[1].megdnn_x = 1;
#endif
    f32[0] = 0;
    f32[1] = 1;
    i32[0] = 0;
    i32[1] = 1;
}

template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
    megdnn_throw("unsupported rocm opr");
    return nullptr;
}

size_t HandleImpl::alignment_requirement() const {
    auto&& prop = m_device_prop;
    MEGDNN_MARK_USED_VAR(prop);
    //! for now, texture functions are not supported.
    return 1u;
}

bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) {
    // is contiguous or can be hold by
    // relayout::param::try_copy_2d/try_copy_last_contig
    return src.is_contiguous() || src.stride[src.ndim - 1] == 1;
}

MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardFilter);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TopK);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdateForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMulForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingOneHotForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetOneHotForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingMultiAxisVec);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetMultiAxisVec);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingIncrMultiAxisVec);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Fill);

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
#pragma GCC diagnostic pop

}  // namespace rocm
}  // namespace megdnn

// Record the HIP and MIOpen versions this file is compiled against, using
// MEGDNN_VERSION_SYMBOL3 (presumably from src/common/version_symbol.h, included
// above — see that header for what the generated symbol encodes).
MEGDNN_VERSION_SYMBOL3(HIP, HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH);
MEGDNN_VERSION_SYMBOL3(
        MIOPEN, MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, MIOPEN_VERSION_PATCH);
// vim: syntax=cpp.doxygen