megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/atlas/megcore/atlas_computing_context.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megcore.h"

#include "src/atlas//megcore/computing_context.hpp"
#include "src/atlas/utils.h"
#include "src/common/utils.h"

using namespace megcore;
using namespace megcore::atlas;

// Builds a computing context on an Atlas device handle. When `ctx.stream` is
// null, a new stream is created and owned (released in the destructor);
// otherwise the caller's stream is used as-is.
AtlasComputingContext::AtlasComputingContext(
        megcoreDeviceHandle_t dev_handle, unsigned int flags, const AtlasContext& ctx)
        : ComputingContext(dev_handle, flags),
          m_own_stream{ctx.stream == nullptr},
          m_ctx{ctx} {
    // The device handle must belong to the Atlas platform.
    megcorePlatform_t platform{};
    megcoreGetPlatform(dev_handle, &platform);
    megdnn_assert(platform == megcorePlatformAtlas);

    // A caller-supplied stream is borrowed; nothing more to set up.
    if (!m_own_stream) {
        return;
    }
    // No stream was given: create one that this context owns.
    acl_check(aclrtCreateStream(&m_ctx.stream));
}

// Releases the stream only if the constructor created it; a stream passed in
// through AtlasContext is left for its owner to destroy.
// NOTE(review): if acl_check throws on failure (TODO confirm), a failing
// aclrtDestroyStream here would escape a destructor and call std::terminate.
AtlasComputingContext::~AtlasComputingContext() {
    if (!m_own_stream) {
        return;
    }
    acl_check(aclrtDestroyStream(m_ctx.stream));
}

void AtlasComputingContext::memcpy(
        void* dst, const void* src, size_t size_in_bytes, megcoreMemcpyKind_t kind) {
    switch (kind) {
        case megcoreMemcpyDeviceToHost:
            acl_check(aclrtMemcpy(
                    dst, size_in_bytes, src, size_in_bytes, ACL_MEMCPY_DEVICE_TO_HOST));
            break;
        case megcoreMemcpyHostToDevice:
            acl_check(aclrtMemcpy(
                    dst, size_in_bytes, src, size_in_bytes, ACL_MEMCPY_HOST_TO_DEVICE));
            break;
        case megcoreMemcpyDeviceToDevice:
            // async d2d is always faster than sync d2d because of SDMA
            acl_check(aclrtMemcpyAsync(
                    dst, size_in_bytes, src, size_in_bytes, ACL_MEMCPY_DEVICE_TO_DEVICE,
                    m_ctx.stream));
            break;
        default:
            megdnn_throw("bad atlas memcpy kind");
    }
}

// Fills `size_in_bytes` bytes at `dst` with `value`.
void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) {
    // aclrtMemset is not given a stream, so first wait for everything already
    // queued on this context's stream before touching `dst`.
    auto stream = m_ctx.stream;
    acl_check(aclrtSynchronizeStream(stream));

    // The destination capacity reported to ACL is exactly the fill size.
    const size_t dst_capacity = size_in_bytes;
    acl_check(aclrtMemset(dst, dst_capacity, value, size_in_bytes));
}

// Waits (via aclrtSynchronizeStream) for the work queued on this context's
// stream to finish.
void AtlasComputingContext::synchronize() {
    const auto stream = m_ctx.stream;
    acl_check(aclrtSynchronizeStream(stream));
}

// vim: syntax=cpp.doxygen