megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/cuda/dropout/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/cuda/dropout/opr_impl.h"

#include <limits>

namespace megdnn {
namespace cuda {

using Param = megdnn::Dropout::Param;

struct DropoutTensorDesc : public TensorDesc {
public:
    DropoutTensorDesc(const TensorLayout& layout) : TensorDesc() {
        set_dropout_desc(layout);
    }
    void set_dropout_desc(const TensorLayout& layout) {
        cudnnDataType_t cudnn_dtype;
        switch (layout.dtype.enumv()) {
            case DTypeEnum::Float32:
                cudnn_dtype = CUDNN_DATA_FLOAT;
                break;
            case DTypeEnum::Float16:
                cudnn_dtype = CUDNN_DATA_HALF;
                break;
            default:
                megdnn_throw("dtype must be float16/float32");
        }
        cudnn_check(cudnnSetTensor4dDescriptor(
                desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, 1, 1,
                layout.total_nr_elems()));
    }
};

size_t DropoutForwardImpl::get_mask_size_in_bytes(const TensorLayout& inp) {
    size_t reserve_space_size_in_bytes = 0;
    DropoutTensorDesc ddesc(inp);
    cudnn_check(
            cudnnDropoutGetReserveSpaceSize(ddesc.desc, &reserve_space_size_in_bytes));
    return reserve_space_size_in_bytes;
}

/*!
 * \brief Run cuDNN dropout forward on \p inp into \p oup.
 *
 * \p mask is handed to cuDNN as the reserve space.
 *
 * The cached dropout state is initialized lazily on the first call. Its
 * descriptor is refreshed whenever param().drop_prob changes. The state
 * must have been created with the current param().seed, otherwise the
 * assert fires.
 */
void DropoutForwardImpl::exec(
        _megdnn_tensor_in inp, _megdnn_tensor_out oup, _megdnn_tensor_out mask,
        _megdnn_workspace workspace) {
    check_exec(inp.layout, oup.layout, mask.layout, workspace.size);

    auto dnn_handle = cudnn_handle(this->handle());
    float drop_prob = param().drop_prob;
    uint64_t seed = param().seed;

    // Lazily set up the cached dropout state on first use.
    if (!dropout_status.initialized())
        dropout_status.set(dnn_handle, seed, drop_prob);

    // A new drop probability only triggers restore_desc() on the existing
    // state (presumably keeping its RNG state — confirm in restore_desc).
    if (drop_prob != dropout_status.drop_prob) {
        dropout_status.drop_prob = drop_prob;
        dropout_status.restore_desc(dnn_handle);
    }
    megdnn_assert(dropout_status.seed == seed);

    DropoutTensorDesc inp_desc(inp.layout), oup_desc(oup.layout);
    // NOTE(review): total_nr_elems() is passed as the reserve-space size in
    // bytes, which presumes mask has a 1-byte dtype — confirm.
    cudnn_check(cudnnDropoutForward(
            dnn_handle, dropout_status.desc.desc, inp_desc.desc, inp.raw_ptr(),
            oup_desc.desc, oup.raw_ptr(), mask.raw_ptr(),
            mask.layout.total_nr_elems()));
}

/*!
 * \brief Run cuDNN dropout backward from \p doup into \p dinp.
 *
 * Uses the reserve-space bytes in \p mask.
 *
 * - cuDNN >= 7: the member descriptor op_desc is restored on every call
 *   from param().drop_prob with a null state pointer. NOTE(review):
 *   presumably fine because the dropped positions come from mask — confirm.
 * - Older cuDNN: uses a lazily initialized dropout_status, in the same way
 *   the forward pass does.
 */
void DropoutBackwardImpl::exec(
        _megdnn_tensor_in doup, _megdnn_tensor_in mask, _megdnn_tensor_out dinp,
        _megdnn_workspace workspace) {
    check_exec(doup.layout, mask.layout, dinp.layout, workspace.size);
    auto dnn_handle = cudnn_handle(this->handle());

#if CUDNN_VERSION >= 7000
    // State-buffer size reported by cuDNN; forwarded to restore() even though
    // no state buffer (nullptr) is supplied.
    size_t state_bytes = 0;
    cudnn_check(cudnnDropoutGetStatesSize(dnn_handle, &state_bytes));

    DropoutTensorDesc doup_desc(doup.layout), dinp_desc(dinp.layout);
    op_desc.restore(dnn_handle, param().drop_prob, nullptr, state_bytes, 0);
    auto&& drop_desc = op_desc;
#else
    float drop_prob = param().drop_prob;
    uint64_t seed = param().seed;

    if (!dropout_status.initialized())
        dropout_status.set(dnn_handle, seed, drop_prob);
    if (drop_prob != dropout_status.drop_prob) {
        dropout_status.drop_prob = drop_prob;
        dropout_status.restore_desc(dnn_handle);
    }

    auto&& drop_desc = dropout_status.desc;
    DropoutTensorDesc doup_desc(doup.layout), dinp_desc(dinp.layout);
#endif

    // Both branches only differ in which dropout descriptor they prepare.
    cudnn_check(cudnnDropoutBackward(
            dnn_handle, drop_desc.desc, doup_desc.desc, doup.raw_ptr(),
            dinp_desc.desc, dinp.raw_ptr(), mask.raw_ptr(),
            mask.layout.total_nr_elems()));
}

}  // namespace cuda
}  // namespace megdnn
// vim: syntax=cpp.doxygen