#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
class DropoutDesc {
public:
DropoutDesc() { cudnn_check(cudnnCreateDropoutDescriptor(&desc)); }
~DropoutDesc() { cudnn_check(cudnnDestroyDropoutDescriptor(desc)); }
void set(
cudnnHandle_t handle, void* status, size_t states_size_in_bytes,
uint64_t seed, float drop_prob) {
cudnn_check(cudnnSetDropoutDescriptor(
desc, handle, drop_prob, status, states_size_in_bytes, seed));
}
void restore(
cudnnHandle_t handle, float drop_prob, void* status,
size_t states_size_in_bytes, uint64_t seed) {
#if CUDNN_VERSION >= 7000
cudnn_check(cudnnRestoreDropoutDescriptor(
desc, handle, drop_prob, status, states_size_in_bytes, 0));
#else
cudnn_check(cudnnSetDropoutDescriptor(
desc, handle, drop_prob, status, states_size_in_bytes, seed));
#endif
}
cudnnDropoutDescriptor_t desc;
};
class DropoutStatus {
void* status;
uint64_t status_size;
uint64_t seed;
float drop_prob;
DropoutDesc desc;
public:
DropoutStatus() {
status = nullptr;
status_size = 0;
}
~DropoutStatus() {
if (status != nullptr)
cuda_check(cudaFree(status));
}
void set(cudnnHandle_t handle, uint64_t seed, float drop_prob) {
this->seed = seed;
this->drop_prob = drop_prob;
cudnn_check(cudnnDropoutGetStatesSize(handle, &status_size));
cuda_check(cudaMalloc(&status, status_size));
desc.set(handle, status, status_size, seed, drop_prob);
}
void restore_desc(cudnnHandle_t handle) {
desc.restore(handle, drop_prob, status, status_size, seed);
}
bool initialized() { return status != nullptr; }
friend class DropoutForwardImpl;
friend class DropoutBackwardImpl;
};
class DropoutForwardImpl final : public DropoutForward {
DropoutStatus dropout_status;
public:
using DropoutForward::DropoutForward;
void exec(
_megdnn_tensor_in inp, _megdnn_tensor_out oup, _megdnn_tensor_out mask,
_megdnn_workspace workspace) override;
size_t get_mask_size_in_bytes(const TensorLayout& inp) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
return 0;
}
};
class DropoutBackwardImpl final : public DropoutBackward {
#if CUDNN_VERSION >= 7000
DropoutDesc op_desc;
#else
DropoutStatus dropout_status;
#endif
public:
using DropoutBackward::DropoutBackward;
void exec(
_megdnn_tensor_in doup, _megdnn_tensor_in mask, _megdnn_tensor_out dinp,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
return 0;
}
};
} }