#include "src/naive/convpooling/opr_impl.h"
#include <cstring>
#include "megdnn/dtype.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
namespace megdnn {
namespace naive {
// Constructs the fused convolution + pooling operator and allocates the two
// naive sub-operators it delegates to. Both sub-operators are bound to the
// same handle as this operator.
//
// NOTE(review): convFwd and poolFwd are created with raw `new`; no matching
// `delete` is visible in this file -- confirm that they are released elsewhere
// (e.g. in a destructor).
ConvPoolingForwardImpl::ConvPoolingForwardImpl(Handle* handle)
        : ConvPoolingForward(handle) {
    auto* const opr_handle = this->handle();
    convFwd = new ConvolutionForwardImpl(opr_handle);
    poolFwd = new PoolingForwardImpl(opr_handle);
}
void ConvPoolingForwardImpl::setParamOfSublayers() {
Convolution::Param& cparam = convFwd->param();
cparam.pad_h = this->param().conv_pad_h;
cparam.pad_w = this->param().conv_pad_w;
cparam.stride_h = this->param().conv_stride_h;
cparam.stride_w = this->param().conv_stride_w;
if (this->param().convMode == ConvPoolingBase::Param::ConvMode::CONVOLUTION) {
cparam.mode = Convolution::Param::Mode::CONVOLUTION;
} else {
cparam.mode = Convolution::Param::Mode::CROSS_CORRELATION;
}
Pooling::Param& pparam = poolFwd->param();
pparam.window_h = this->param().pool_shape_h;
pparam.window_w = this->param().pool_shape_w;
pparam.stride_h = this->param().pool_stride_h;
pparam.stride_w = this->param().pool_stride_w;
pparam.pad_h = this->param().pool_pad_h;
pparam.pad_w = this->param().pool_pad_w;
if (this->param().poolMode == ConvPoolingBase::Param::PoolMode::AVERAGE) {
pparam.mode = PoolingBase::Param::Mode::AVERAGE;
} else {
pparam.mode = PoolingBase::Param::Mode::MAX;
}
}
void ConvPoolingForwardImpl::check_layout(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
TensorLayout& dst, size_t ) {
TensorLayout dst_expected;
this->deduce_layout(src, filter, bias, dst_expected);
megdnn_assert_eq_layout(dst_expected, dst);
megdnn_assert(bias.shape[1] == dst.shape[1]);
megdnn_assert(dst.shape[1] == filter.shape[0]);
return;
}
void ConvPoolingForwardImpl::deduce_layout(
const TensorLayout& srcl, const TensorLayout& filterl,
const TensorLayout& , TensorLayout& dstl) {
setParamOfSublayers();
convFwd->deduce_layout(srcl, filterl, conv_dst_layout);
poolFwd->deduce_layout(conv_dst_layout, dstl);
}
size_t ConvPoolingForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
const TensorLayout& ) {
TensorLayout tmp_layout;
this->deduce_layout(src, filter, bias, tmp_layout);
return conv_dst_layout.total_nr_elems() * sizeof(float);
}
void ConvPoolingForwardImpl::exec(
const _megdnn_in TensorND src, const _megdnn_in TensorND filter,
const _megdnn_in TensorND bias, _megdnn_out TensorND dst,
_megdnn_out Workspace workspace) {
Workspace empty_wsp;
TensorND conv_dst{workspace.raw_ptr, conv_dst_layout};
check_layout(src.layout, filter.layout, bias.layout, dst.layout, workspace.size);
convFwd->exec(src, filter, conv_dst, nullptr, empty_wsp);
int conv_dst_batch = conv_dst.layout.shape[0];
int conv_dst_channel = conv_dst.layout.shape[1];
int chann_stride = conv_dst.layout.shape[2] * conv_dst.layout.shape[3];
float* conv_dst_ptr = conv_dst.ptr<float>();
for (int batch = 0; batch < conv_dst_batch; ++batch) {
for (int chan = 0; chan < conv_dst_channel; ++chan) {
float bias_val = bias.ptr<float>()[chan];
for (int i = 0; i < chann_stride; ++i, ++conv_dst_ptr) {
conv_dst_ptr[0] += bias_val;
}
}
}
nonlineFwd = new ElemwiseForwardImpl(this->handle());
switch (this->param().nonlineMode) {
case Param::NonlineMode::RELU:
nonlineFwd->param().mode = Elemwise::Param::Mode::RELU;
nonlineFwd->exec({conv_dst}, conv_dst);
break;
case Param::NonlineMode::SIGMOID:
nonlineFwd->param().mode = Elemwise::Param::Mode::SIGMOID;
nonlineFwd->exec({conv_dst}, conv_dst);
break;
case Param::NonlineMode::IDENTITY:
break;
default:
break;
}
poolFwd->exec(conv_dst, dst, empty_wsp);
}
} }