#include "./opr_impl.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
namespace batch_normalization {
BNTensorDescHolder::BNTensorDescHolder(
const TensorLayout& x, const ParamDim& param_dim, const FwdMode& fwd_mode) {
TensorShape xy_shape(x);
Format xy_format = Format::NCHW;
switch (param_dim) {
case ParamDim::DIM_11HW:
xy_shape.shape[0] = xy_shape.shape[0] * xy_shape.shape[1];
xy_shape.shape[1] = 1;
bn_mode = CUDNN_BATCHNORM_PER_ACTIVATION;
break;
case ParamDim::DIM_1CHW:
bn_mode = CUDNN_BATCHNORM_PER_ACTIVATION;
break;
case ParamDim::DIM_1C11:
bn_mode = CUDNN_BATCHNORM_SPATIAL;
break;
case ParamDim::DIM_111C:
bn_mode = CUDNN_BATCHNORM_SPATIAL;
xy_format = Format::NHWC;
#if CUDNN_VERSION >= 7410
if (fwd_mode == FwdMode::TRAINING) {
bn_mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
}
#endif break;
default:
megdnn_throw("Unknown param dim type of batch normalization.");
}
xy_desc.set(TensorLayout(xy_shape, x.dtype), xy_format);
param_desc.set(xy_desc.desc, bn_mode);
}
size_t get_reserve_size(
const cudnnHandle_t& handle, const BNTensorDescHolder& tensor_desc) {
#if CUDNN_VERSION >= 7410
size_t reserve_size;
cudnn_check(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN,
nullptr, tensor_desc.xy_desc.desc, &reserve_size));
return reserve_size;
#else
return 0;
#endif }
}
using batch_normalization::BNTensorDescHolder;
size_t BNForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&) {
#if CUDNN_VERSION >= 7410
auto handle = cudnn_handle(this->handle());
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode);
size_t workspace_size;
cudnn_check(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN,
tensor_desc.xy_desc.desc, tensor_desc.xy_desc.desc, tensor_desc.xy_desc.desc, tensor_desc.param_desc.desc, nullptr, &workspace_size));
return workspace_size;
#else
return 0;
#endif }
size_t BNForwardImpl::get_reserve_in_bytes(const TensorLayout& src) {
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode);
return batch_normalization::get_reserve_size(
cudnn_handle(this->handle()), tensor_desc);
}
void BNForwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, _megdnn_tensor_in bn_bias,
_megdnn_tensor_out mean, _megdnn_tensor_out variance,
_megdnn_tensor_out batch_mean, _megdnn_tensor_out batch_inv_variance,
_megdnn_tensor_out reserve, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(
src.layout, bn_scale.layout, bn_bias.layout, mean.layout, variance.layout,
batch_mean.layout, batch_inv_variance.layout, dst.layout, workspace.size,
reserve.layout.access_bytes());
auto handle = cudnn_handle(this->handle());
BNTensorDescHolder tensor_desc(src.layout, m_param.param_dim, m_param.fwd_mode);
float alpha = 1.0f, beta = 0.0f;
switch (m_param.fwd_mode) {
case param::BN::FwdMode::TRAINING:
#if CUDNN_VERSION >= 7410
cudnn_check(cudnnBatchNormalizationForwardTrainingEx(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha,
&beta, tensor_desc.xy_desc.desc, src.raw_ptr(), nullptr, nullptr, tensor_desc.xy_desc.desc, dst.raw_ptr(), tensor_desc.param_desc.desc, bn_scale.raw_ptr(), bn_bias.raw_ptr(), m_param.avg_factor,
mean.raw_ptr(), variance.raw_ptr(), m_param.epsilon,
batch_mean.raw_ptr(), batch_inv_variance.raw_ptr(), nullptr,
workspace.raw_ptr, workspace.size, reserve.raw_ptr(),
reserve.layout.access_bytes()));
#else
cudnn_check(cudnnBatchNormalizationForwardTraining(
handle, tensor_desc.bn_mode, &alpha, &beta,
tensor_desc.xy_desc.desc, src.raw_ptr(), tensor_desc.xy_desc.desc, dst.raw_ptr(), tensor_desc.param_desc.desc, bn_scale.raw_ptr(), bn_bias.raw_ptr(), m_param.avg_factor,
mean.raw_ptr(), variance.raw_ptr(), m_param.epsilon,
batch_mean.raw_ptr(), batch_inv_variance.raw_ptr()));
#endif break;
case param::BN::FwdMode::INFERENCE:
cudnn_check(cudnnBatchNormalizationForwardInference(
handle, tensor_desc.bn_mode, &alpha, &beta,
tensor_desc.xy_desc.desc, src.raw_ptr(), tensor_desc.xy_desc.desc,
dst.raw_ptr(), tensor_desc.param_desc.desc, bn_scale.raw_ptr(),
bn_bias.raw_ptr(), mean.raw_ptr(), variance.raw_ptr(),
m_param.epsilon));
break;
default:
megdnn_throw("Unknown forward mode type of batch normalization.");
}
}
size_t BNBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&) {
#if CUDNN_VERSION >= 7410
auto handle = cudnn_handle(this->handle());
BNTensorDescHolder tensor_desc(x, m_param.param_dim, m_param.fwd_mode);
size_t workspace_size;
cudnn_check(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN,
tensor_desc.xy_desc.desc, tensor_desc.xy_desc.desc, tensor_desc.xy_desc.desc, nullptr, tensor_desc.xy_desc.desc, tensor_desc.param_desc.desc, nullptr, &workspace_size));
return workspace_size;
#else
return 0;
#endif }
size_t BNBackwardImpl::get_reserve_in_bytes(const TensorLayout& src) {
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode);
return batch_normalization::get_reserve_size(
cudnn_handle(this->handle()), tensor_desc);
}
void BNBackwardImpl::exec(
_megdnn_tensor_in x, _megdnn_tensor_in dy, _megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance, _megdnn_tensor_in bn_scale,
_megdnn_tensor_in reserve, _megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_workspace workspace) {
check_exec(
x.layout, dy.layout, saved_batch_mean.layout,
saved_batch_inv_variance.layout, bn_scale.layout, d_bn_scale.layout,
d_bn_bias.layout, dx.layout, workspace.size, reserve.layout.access_bytes());
auto handle = cudnn_handle(this->handle());
BNTensorDescHolder tensor_desc(x.layout, m_param.param_dim, m_param.fwd_mode);
float alpha = 1.0, beta = 0.0;
#if CUDNN_VERSION >= 7410
cudnn_check(cudnnBatchNormalizationBackwardEx(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, &beta, &alpha,
&beta, tensor_desc.xy_desc.desc,
x.raw_ptr(), nullptr, nullptr, tensor_desc.xy_desc.desc, dy.raw_ptr(), nullptr, nullptr, tensor_desc.xy_desc.desc, dx.raw_ptr(), tensor_desc.param_desc.desc, bn_scale.raw_ptr(), nullptr, d_bn_scale.raw_ptr(), d_bn_bias.raw_ptr(), m_param.epsilon, saved_batch_mean.raw_ptr(),
saved_batch_inv_variance.raw_ptr(), nullptr, workspace.raw_ptr,
workspace.size, reserve.raw_ptr(), reserve.layout.access_bytes()));
#else
cudnn_check(cudnnBatchNormalizationBackward(
handle, tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta,
tensor_desc.xy_desc.desc, x.raw_ptr(), tensor_desc.xy_desc.desc, dy.raw_ptr(), tensor_desc.xy_desc.desc, dx.raw_ptr(), tensor_desc.param_desc.desc, bn_scale.raw_ptr(), d_bn_scale.raw_ptr(), d_bn_bias.raw_ptr(), m_param.epsilon, saved_batch_mean.raw_ptr(),
saved_batch_inv_variance.raw_ptr()));
#endif
}
} }