#pragma once
#include "./opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/utils.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
namespace convolution3d {
//! Local shorthand for the CanonizedFilterMeta type nested in Convolution3DForward.
using CanonizedFilterMeta = Convolution3DForward::CanonizedFilterMeta;
/*!
 * \brief Plain bundle of the inputs describing one 3D convolution problem.
 *
 * Taken by const reference by is_cudnn_supported() and flip_filter() declared
 * in this header. All pointer members are raw pointers; nothing in this
 * struct manages their lifetime.
 */
struct ForwardSizeArgs {
//! handle the problem is associated with
HandleImpl* handle;
//! layout of the source tensor
const TensorLayout* src_layout;
//! layout of the filter tensor
const TensorLayout* filter_layout;
//! filter metadata, stored by value
//! NOTE(review): presumably derived from filter_layout — confirm at the
//! construction sites
CanonizedFilterMeta filter_meta;
//! layout of the destination tensor
const TensorLayout* dst_layout;
//! data type setting taken from param::Convolution3D
param::Convolution3D::DataType data_type;
};
/*!
 * \brief Declaration only; the definition is not part of this header.
 *
 * NOTE(review): judging by the name, this reports whether cuDNN can handle
 * the problem described by \p args — confirm against the definition.
 */
bool is_cudnn_supported(const ForwardSizeArgs& args);
/*!
 * \brief Descriptor objects used together for a forward 3D convolution:
 *      source tensor, destination tensor, filter and the convolution itself.
 */
struct CUDNNForwardDescs {
    Tensor3DDesc src_desc;
    Tensor3DDesc dst_desc;
    Filter3DDesc filter_desc;
    Conv3DDesc conv_desc;

    /*!
     * \brief Fill every descriptor held by this struct.
     * \param src layout forwarded to src_desc
     * \param filter filter metadata forwarded to filter_desc; its `group`
     *      field is also forwarded to conv_desc
     * \param dst layout forwarded to dst_desc
     * \param param convolution parameters forwarded to conv_desc
     */
    void set(
            const TensorLayout& src, const CanonizedFilterMeta& filter,
            const TensorLayout& dst, const param::Convolution3D& param) {
        // tensor / filter descriptors first, in the same order as before
        src_desc.set(src);
        filter_desc.set(filter);
        dst_desc.set(dst);

        // the convolution descriptor additionally needs the group count
        conv_desc.set(param, filter.group);
    }
};
/*!
 * \brief Descriptor objects used together for the backward-data pass of a 3D
 *      convolution: `diff` tensor, `grad` tensor, filter and the convolution.
 *
 * NOTE(review): `diff` is presumably the gradient w.r.t. the convolution
 * output and `grad` the gradient w.r.t. its input — confirm with the callers.
 */
struct CUDNNBwdDataDescs {
    Tensor3DDesc diff_desc;
    Tensor3DDesc grad_desc;
    Filter3DDesc filter_desc;
    Conv3DDesc conv_desc;

    /*!
     * \brief Fill every descriptor held by this struct.
     * \param filter filter metadata forwarded to filter_desc; its `group`
     *      field is also forwarded to conv_desc
     * \param diff layout forwarded to diff_desc
     * \param grad layout forwarded to grad_desc
     * \param param convolution parameters forwarded to conv_desc
     */
    void set(
            const CanonizedFilterMeta& filter, const TensorLayout& diff,
            const TensorLayout& grad, const param::Convolution3D& param) {
        // filter / tensor descriptors first, in the same order as before
        filter_desc.set(filter);
        diff_desc.set(diff);
        grad_desc.set(grad);

        // the convolution descriptor additionally needs the group count
        conv_desc.set(param, filter.group);
    }
};
/*!
 * \brief Descriptor objects used together for the backward-filter pass of a
 *      3D convolution: `diff` tensor, source tensor, filter-shaped `grad`
 *      and the convolution itself.
 *
 * NOTE(review): `diff` is presumably the gradient w.r.t. the convolution
 * output and `grad` the gradient w.r.t. the filter — confirm with the callers.
 */
struct CUDNNBwdFilterDescs {
    Tensor3DDesc diff_desc;
    Tensor3DDesc src_desc;
    Filter3DDesc grad_desc;
    Conv3DDesc conv_desc;

    /*!
     * \brief Fill every descriptor held by this struct.
     * \param src layout forwarded to src_desc
     * \param diff layout forwarded to diff_desc
     * \param grad filter metadata forwarded to grad_desc; its `group` field
     *      is also forwarded to conv_desc
     * \param param convolution parameters forwarded to conv_desc
     */
    void set(
            const TensorLayout& src, const TensorLayout& diff,
            const CanonizedFilterMeta& grad, const param::Convolution3D& param) {
        // tensor / filter descriptors first, in the same order as before
        src_desc.set(src);
        diff_desc.set(diff);
        grad_desc.set(grad);

        // the convolution descriptor additionally needs the group count
        conv_desc.set(param, grad.group);
    }
};
/*!
 * \brief Declaration only; the definition is not part of this header.
 *
 * NOTE(review): judging by the name and signature, this produces a flipped
 * copy of the filter using \p workspace and communicates the result through
 * \p raw_ptr — confirm against the definition before relying on it.
 */
void flip_filter(
const ForwardSizeArgs& args, const Workspace& workspace, RefPtr& raw_ptr);
/*!
 * \brief Choose a cuDNN forward-convolution algorithm whose workspace fits in
 *      \p workspace_limit_in_bytes.
 *
 * cuDNN >= 7: the candidates filled in by
 * cudnnGetConvolutionForwardAlgorithm_v7 are visited in the order cuDNN
 * returned them and the first acceptable one is written to \p algo. A
 * candidate is rejected when
 *  - its per-candidate `status` is not CUDNN_STATUS_SUCCESS,
 *  - it is CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
 *  - its workspace requirement exceeds \p workspace_limit_in_bytes, or
 *  - \p positive_attr contains AlgoAttribute::REPRODUCIBLE and the candidate
 *    is not reported as CUDNN_DETERMINISTIC.
 *
 * Older cuDNN: the choice is delegated to cudnnGetConvolutionForwardAlgorithm
 * with CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT; the attribute arguments
 * are not consulted on that path.
 *
 * \param negative_attr currently not consulted by this helper
 * \param[out] algo receives the chosen algorithm; left untouched when the
 *      function returns false
 * \return true if an algorithm was written to \p algo, false if no candidate
 *      passed the filters above
 */
inline bool cudnn_get_convolution_fwd_algo_helper(
        cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc,
        const cudnnFilterDescriptor_t w_desc,
        const cudnnConvolutionDescriptor_t conv_desc,
        const cudnnTensorDescriptor_t y_desc, size_t workspace_limit_in_bytes,
        cudnnConvolutionFwdAlgo_t* algo, const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    MEGDNN_MARK_USED_VAR(positive_attr);
    MEGDNN_MARK_USED_VAR(negative_attr);
#if CUDNN_MAJOR >= 7
    int algo_max_count = 0;
    cudnn_check(
            cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &algo_max_count));
    SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(algo_max_count);
    int algo_count = 0;
    cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7(
            cudnn_handle, x_desc, w_desc, conv_desc, y_desc, algo_max_count,
            &algo_count, algo_perf.data()));
    // loop-invariant: without REPRODUCIBLE any determinism level is accepted
    const bool accept_any_determinism =
            !(positive_attr & AlgoAttribute::REPRODUCIBLE);
    for (int i = 0; i < algo_count; ++i) {
        const cudnnConvolutionFwdAlgoPerf_t& perf = algo_perf[i];
        // cuDNN reports per-candidate failures through `status`; such a
        // candidate must not be queried for workspace size or selected
        if (perf.status != CUDNN_STATUS_SUCCESS)
            continue;
        if (perf.algo ==
            cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING)
            continue;
        size_t workspace_size = 0;
        cudnn_check(cudnnGetConvolutionForwardWorkspaceSize(
                cudnn_handle, x_desc, w_desc, conv_desc, y_desc, perf.algo,
                &workspace_size));
        if (workspace_size > workspace_limit_in_bytes)
            continue;
        if (accept_any_determinism || perf.determinism == CUDNN_DETERMINISTIC) {
            *algo = perf.algo;
            return true;
        }
    }
    return false;
#else
    cudnn_check(cudnnGetConvolutionForwardAlgorithm(
            cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
            CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_limit_in_bytes,
            algo));
    return true;
#endif
}
/*!
 * \brief Choose a cuDNN backward-data convolution algorithm whose workspace
 *      fits in \p workspace_limit_in_bytes.
 *
 * cuDNN >= 7: the candidates filled in by
 * cudnnGetConvolutionBackwardDataAlgorithm_v7 are visited in the order cuDNN
 * returned them and the first acceptable one is written to \p algo. A
 * candidate is rejected when
 *  - its per-candidate `status` is not CUDNN_STATUS_SUCCESS,
 *  - it is CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
 *  - its workspace requirement exceeds \p workspace_limit_in_bytes, or
 *  - \p positive_attr contains AlgoAttribute::REPRODUCIBLE and the candidate
 *    is not reported as CUDNN_DETERMINISTIC.
 *
 * Older cuDNN: the choice is delegated to
 * cudnnGetConvolutionBackwardDataAlgorithm with
 * CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT; the attribute arguments
 * are not consulted on that path.
 *
 * \param negative_attr currently not consulted by this helper
 * \param[out] algo receives the chosen algorithm; left untouched when the
 *      function returns false
 * \return true if an algorithm was written to \p algo, false if no candidate
 *      passed the filters above
 */
inline bool cudnn_get_convolution_bwd_data_algo_helper(
        cudnnHandle_t cudnn_handle, const cudnnFilterDescriptor_t w_desc,
        const cudnnTensorDescriptor_t dy_desc,
        const cudnnConvolutionDescriptor_t conv_desc,
        const cudnnTensorDescriptor_t dx_desc, size_t workspace_limit_in_bytes,
        cudnnConvolutionBwdDataAlgo_t* algo, const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    MEGDNN_MARK_USED_VAR(positive_attr);
    MEGDNN_MARK_USED_VAR(negative_attr);
#if CUDNN_MAJOR >= 7
    int algo_max_count = 0;
    cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(
            cudnn_handle, &algo_max_count));
    SmallVector<cudnnConvolutionBwdDataAlgoPerf_t> algo_perf(algo_max_count);
    int algo_count = 0;
    cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm_v7(
            cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc, algo_max_count,
            &algo_count, algo_perf.data()));
    // loop-invariant: without REPRODUCIBLE any determinism level is accepted
    const bool accept_any_determinism =
            !(positive_attr & AlgoAttribute::REPRODUCIBLE);
    for (int i = 0; i < algo_count; ++i) {
        const cudnnConvolutionBwdDataAlgoPerf_t& perf = algo_perf[i];
        // cuDNN reports per-candidate failures through `status`; such a
        // candidate must not be queried for workspace size or selected
        if (perf.status != CUDNN_STATUS_SUCCESS)
            continue;
        if (perf.algo ==
            cudnnConvolutionBwdDataAlgo_t::CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING)
            continue;
        size_t workspace_size = 0;
        cudnn_check(cudnnGetConvolutionBackwardDataWorkspaceSize(
                cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc, perf.algo,
                &workspace_size));
        if (workspace_size > workspace_limit_in_bytes)
            continue;
        if (accept_any_determinism || perf.determinism == CUDNN_DETERMINISTIC) {
            *algo = perf.algo;
            return true;
        }
    }
    return false;
#else
    cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm(
            cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc,
            CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
            workspace_limit_in_bytes, algo));
    return true;
#endif
}
/*!
 * \brief Choose a cuDNN backward-filter convolution algorithm whose workspace
 *      fits in \p workspace_limit_in_bytes.
 *
 * cuDNN >= 7: the candidates filled in by
 * cudnnGetConvolutionBackwardFilterAlgorithm_v7 are visited in the order
 * cuDNN returned them and the first acceptable one is written to \p algo. A
 * candidate is rejected when
 *  - its per-candidate `status` is not CUDNN_STATUS_SUCCESS,
 *  - it is CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
 *  - its workspace requirement exceeds \p workspace_limit_in_bytes, or
 *  - \p positive_attr contains AlgoAttribute::REPRODUCIBLE and the candidate
 *    is not reported as CUDNN_DETERMINISTIC.
 *
 * Older cuDNN: the choice is delegated to
 * cudnnGetConvolutionBackwardFilterAlgorithm with
 * CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT; the attribute
 * arguments are not consulted on that path.
 *
 * \param negative_attr currently not consulted by this helper
 * \param[out] algo receives the chosen algorithm; left untouched when the
 *      function returns false
 * \return true if an algorithm was written to \p algo, false if no candidate
 *      passed the filters above
 */
inline bool cudnn_get_convolution_bwd_filter_algo_helper(
        cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc,
        const cudnnTensorDescriptor_t dy_desc,
        const cudnnConvolutionDescriptor_t conv_desc,
        const cudnnFilterDescriptor_t dw_desc, size_t workspace_limit_in_bytes,
        cudnnConvolutionBwdFilterAlgo_t* algo, const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    MEGDNN_MARK_USED_VAR(positive_attr);
    MEGDNN_MARK_USED_VAR(negative_attr);
#if CUDNN_MAJOR >= 7
    int algo_max_count = 0;
    cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(
            cudnn_handle, &algo_max_count));
    SmallVector<cudnnConvolutionBwdFilterAlgoPerf_t> algo_perf(algo_max_count);
    int algo_count = 0;
    cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
            cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, algo_max_count,
            &algo_count, algo_perf.data()));
    // loop-invariant: without REPRODUCIBLE any determinism level is accepted
    const bool accept_any_determinism =
            !(positive_attr & AlgoAttribute::REPRODUCIBLE);
    for (int i = 0; i < algo_count; ++i) {
        const cudnnConvolutionBwdFilterAlgoPerf_t& perf = algo_perf[i];
        // cuDNN reports per-candidate failures through `status`; such a
        // candidate must not be queried for workspace size or selected
        if (perf.status != CUDNN_STATUS_SUCCESS)
            continue;
        if (perf.algo == cudnnConvolutionBwdFilterAlgo_t::
                                 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING)
            continue;
        size_t workspace_size = 0;
        cudnn_check(cudnnGetConvolutionBackwardFilterWorkspaceSize(
                cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, perf.algo,
                &workspace_size));
        if (workspace_size > workspace_limit_in_bytes)
            continue;
        if (accept_any_determinism || perf.determinism == CUDNN_DETERMINISTIC) {
            *algo = perf.algo;
            return true;
        }
    }
    return false;
#else
    cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm(
            cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc,
            CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
            workspace_limit_in_bytes, algo));
    return true;
#endif
}
} } }