#include "./opr_impl.h"
#include "src/rocm/utils.h"
namespace megdnn {
namespace rocm {
namespace batch_normalization {
void BNTensorDescHolder::setup(const TensorLayout& x, const ParamDim& param_dim) {
TensorShape xy_shape(x);
switch (param_dim) {
case ParamDim::DIM_11HW:
xy_shape.shape[0] = xy_shape.shape[0] * xy_shape.shape[1];
xy_shape.shape[1] = 1;
bn_mode = miopenBNPerActivation;
break;
case ParamDim::DIM_1CHW:
bn_mode = miopenBNPerActivation;
break;
case ParamDim::DIM_1C11:
bn_mode = miopenBNSpatial;
break;
default:
megdnn_throw("Unknown param dim type of batch normalization.");
}
xy_desc.set(TensorLayout(xy_shape, x.dtype));
param_desc.set(xy_desc.desc, bn_mode);
}
}
/*!
 * \brief Batch-normalization forward pass, dispatched to MIOpen.
 *
 * Depending on m_param.fwd_mode this calls either
 * miopenBatchNormalizationForwardTraining or
 * miopenBatchNormalizationForwardInference. Both read \p src and write
 * \p dst, using the descriptors prepared by m_tensor_desc.setup().
 *
 * \param src                 input tensor
 * \param bn_scale            scale parameter
 * \param bn_bias             bias parameter
 * \param mean                passed as the running mean (training) or the
 *                            estimated mean (inference)
 * \param variance            passed as the running variance (training) or the
 *                            estimated variance (inference)
 * \param batch_mean          passed as the saved mean; training mode only
 * \param batch_inv_variance  passed as the saved inverse variance; training
 *                            mode only
 * \param dst                 output tensor
 * \param workspace           only its size is used here (given to check_exec)
 *
 * The unnamed output tensor parameter is not used by this implementation.
 * An unrecognised fwd_mode is reported through megdnn_throw.
 */
void BNForwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in bn_scale, _megdnn_tensor_in bn_bias,
        _megdnn_tensor_out mean, _megdnn_tensor_out variance,
        _megdnn_tensor_out batch_mean, _megdnn_tensor_out batch_inv_variance,
        _megdnn_tensor_out, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(
            src.layout, bn_scale.layout, bn_bias.layout, mean.layout, variance.layout,
            batch_mean.layout, batch_inv_variance.layout, dst.layout, workspace.size);

    auto miopen_hdl = concrete_handle(this->handle())->miopen_handle();
    m_tensor_desc.setup(src.layout, m_param.param_dim);

    // Shorthands for the state that setup() has just refreshed.
    const auto& mode = m_tensor_desc.bn_mode;
    const auto& data_desc = m_tensor_desc.xy_desc.desc;
    const auto& stat_desc = m_tensor_desc.param_desc.desc;

    // Kept non-const: their addresses are handed to MIOpen as alpha/beta.
    float alpha = 1.0f;
    float beta = 0.0f;

    if (m_param.fwd_mode == param::BN::FwdMode::TRAINING) {
        miopen_check(miopenBatchNormalizationForwardTraining(
                miopen_hdl, mode, &alpha, &beta,
                data_desc, src.raw_ptr(),                // xDesc, x
                data_desc, dst.raw_ptr(),                // yDesc, y
                stat_desc,                               // bnScaleBiasMeanVarDesc
                bn_scale.raw_ptr(), bn_bias.raw_ptr(),   // bnScale, bnBias
                m_param.avg_factor,                      // expAvgFactor
                mean.raw_ptr(), variance.raw_ptr(),      // running mean / variance
                m_param.epsilon,
                batch_mean.raw_ptr(),                    // saved mean
                batch_inv_variance.raw_ptr()));          // saved inverse variance
    } else if (m_param.fwd_mode == param::BN::FwdMode::INFERENCE) {
        miopen_check(miopenBatchNormalizationForwardInference(
                miopen_hdl, mode, &alpha, &beta,
                data_desc, src.raw_ptr(),                // xDesc, x
                data_desc, dst.raw_ptr(),                // yDesc, y
                stat_desc,                               // bnScaleBiasMeanVarDesc
                bn_scale.raw_ptr(), bn_bias.raw_ptr(),   // bnScale, bnBias
                mean.raw_ptr(), variance.raw_ptr(),      // estimated mean / variance
                m_param.epsilon));
    } else {
        megdnn_throw("Unknown forward mode type of batch normalization.");
    }
}
/*!
 * \brief Batch-normalization backward pass, dispatched to MIOpen.
 *
 * Calls miopenBatchNormalizationBackward with the descriptors prepared by
 * m_tensor_desc.setup(); the same data descriptor is used for \p x, \p dy
 * and \p dx.
 *
 * \param x                         forward input tensor
 * \param dy                        gradient w.r.t. the forward output
 * \param saved_batch_mean          passed as the saved mean
 * \param saved_batch_inv_variance  passed as the saved inverse variance
 * \param bn_scale                  scale parameter
 * \param d_bn_scale                receives the scale gradient
 * \param d_bn_bias                 receives the bias gradient
 * \param dx                        receives the gradient w.r.t. \p x
 * \param workspace                 only its size is used here (given to
 *                                  check_exec)
 *
 * The unnamed input tensor parameter is not used by this implementation.
 */
void BNBackwardImpl::exec(
        _megdnn_tensor_in x, _megdnn_tensor_in dy, _megdnn_tensor_in saved_batch_mean,
        _megdnn_tensor_in saved_batch_inv_variance, _megdnn_tensor_in bn_scale,
        _megdnn_tensor_in, _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias,
        _megdnn_tensor_out dx, _megdnn_workspace workspace) {
    check_exec(
            x.layout, dy.layout, saved_batch_mean.layout,
            saved_batch_inv_variance.layout, bn_scale.layout, d_bn_scale.layout,
            d_bn_bias.layout, dx.layout, workspace.size);

    auto miopen_hdl = concrete_handle(this->handle())->miopen_handle();
    m_tensor_desc.setup(x.layout, m_param.param_dim);

    // Shorthands for the state that setup() has just refreshed.
    const auto& mode = m_tensor_desc.bn_mode;
    const auto& data_desc = m_tensor_desc.xy_desc.desc;
    const auto& diff_desc = m_tensor_desc.param_desc.desc;

    // The same (1, 0) pair is supplied for both the data-gradient and the
    // parameter-gradient alpha/beta arguments. Kept non-const because their
    // addresses are handed to MIOpen.
    float one = 1.0f;
    float zero = 0.0f;

    miopen_check(miopenBatchNormalizationBackward(
            miopen_hdl, mode,
            &one, &zero,                                  // alpha/beta for dx
            &one, &zero,                                  // alpha/beta for d_bn_*
            data_desc, x.raw_ptr(),                       // xDesc, x
            data_desc, dy.raw_ptr(),                      // dyDesc, dy
            data_desc, dx.raw_ptr(),                      // dxDesc, dx
            diff_desc,                                    // bnScaleBiasDiffDesc
            bn_scale.raw_ptr(),                           // bnScale
            d_bn_scale.raw_ptr(), d_bn_bias.raw_ptr(),    // scale / bias gradients
            m_param.epsilon,
            saved_batch_mean.raw_ptr(),                   // saved mean
            saved_batch_inv_variance.raw_ptr()));         // saved inverse variance
}
} }