#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"
using namespace megdnn;
using CanonizedFilterMeta = DeformableConvBase::CanonizedFilterMeta;
namespace {
template <typename Param>
std::string get_errmsg(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& dst, const Param& param) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
MEGDNN_MARK_USED_VAR(dst);
return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
megdnn_layout_msg(offset) + ", " + megdnn_layout_msg(mask) + ", " +
megdnn_layout_msg(dst) + ", " + "only support nchw" + ", " +
"group=" + std::to_string(param.group) + ", " +
"deformable_group=" + std::to_string(param.deformable_group) + ", " +
"pad_h=" + std::to_string(param.pad_h) + ", " +
"pad_w=" + std::to_string(param.pad_w) + ", " +
"stride_h=" + std::to_string(param.stride_h) + ", " +
"stride_w=" + std::to_string(param.stride_w) + ", " +
"dilate_h=" + std::to_string(param.dilate_h) + ", " +
"dilate_w=" + std::to_string(param.dilate_w);
}
template <typename Param>
void make_canonized_filter_meta_nchw(
size_t src_ndim, const TensorLayout& filter, const Param& param,
CanonizedFilterMeta& ret) {
megdnn_assert(
param.mode == Param::Mode::CROSS_CORRELATION,
"only support CROSS_CORRELATION mode");
megdnn_assert(
param.format == Param::Format::NCHW, "only support nchw input layout");
size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
flt_start = 0, flt_spatial_start = 2;
ocpg_pos = 0, icpg_pos = 1;
if (param.sparse == Param::Sparse::GROUP)
flt_start = 1;
ret.spatial_ndim = src_ndim - 2;
megdnn_assert(
ret.spatial_ndim == 2,
"only 2D convolution is supported, and imput should be 4-dim; "
"got input dim = %zu",
src_ndim);
ret.ocpg = filter[flt_start + ocpg_pos];
ret.icpg = filter[flt_start + icpg_pos];
auto dilation = ret.dilation;
for (size_t i = 0; i < ret.spatial_ndim; ++i) {
megdnn_assert(
dilation[i] > 0, "invalid dilation on spatial dim %zu, %u", i,
dilation[i]);
ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
}
}
}
namespace megdnn {
CanonizedFilterMeta DeformableConvBase::make_canonized_filter_meta(
size_t src_ndim, const TensorLayout& filter, const TensorLayout& offset) const {
megdnn_assert_contiguous(filter);
CanonizedFilterMeta ret;
ret.group = 1;
ret.dtype = filter.dtype;
ret.stride[0] = param().stride_h;
ret.stride[1] = param().stride_w;
ret.padding[0] = param().pad_h;
ret.padding[1] = param().pad_w;
ret.dilation[0] = param().dilate_h;
ret.dilation[1] = param().dilate_w;
if (param().sparse == Param::Sparse::GROUP) {
megdnn_assert(filter.ndim == 5, "filter dim should be 5 for group conv");
ret.group = filter[0];
}
make_canonized_filter_meta_nchw(src_ndim, filter, param(), ret);
auto fh = ret.spatial[0];
auto fw = ret.spatial[1];
ret.deformable_group = offset[1] / (2 * fh * fw);
return ret;
}
void DeformableConvBase::deduce_layout_fwd(
const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset,
const TensorLayout& mask, TensorLayout& dst) {
megdnn_assert(
im.ndim == 4, "invalid src layout: %s", megdnn_layout_msg(im).c_str());
megdnn_assert(
filter.ndim == 4 || filter.ndim == 5, "invalid filter layout: %s",
megdnn_layout_msg(filter).c_str());
megdnn_assert(
offset.ndim == 4, "invalid offset layout: %s",
megdnn_layout_msg(offset).c_str());
megdnn_assert(
mask.ndim == 4, "invalid mask layout: %s", megdnn_layout_msg(mask).c_str());
size_t n = im.shape[0], ic = im.shape[1];
size_t ih = im.shape[2], iw = im.shape[3];
size_t dh = param().dilate_h, dw = param().dilate_w;
size_t ph = param().pad_h, pw = param().pad_w;
size_t sh = param().stride_h, sw = param().stride_w;
auto&& fm = make_canonized_filter_meta(im.ndim, filter, offset);
size_t fh = fm.spatial[0], fw = fm.spatial[1];
size_t kh = 1 + (fh - 1) * dh;
size_t kw = 1 + (fw - 1) * dw;
size_t group = fm.group;
size_t deformable_group = fm.deformable_group;
size_t icpg = fm.icpg, ocpg = fm.ocpg;
size_t oc = group * ocpg;
size_t oh = (ih + ph * 2 - kh) / sh + 1;
size_t ow = (iw + pw * 2 - kw) / sw + 1;
megdnn_assert(
group > 0 && deformable_group > 0, "group and deformable group should > 0");
megdnn_assert(ic == icpg * group, "im ic != group * icpg of filter");
megdnn_assert(ic % deformable_group == 0, "ic %% deformable_group != 0");
megdnn_assert(oc % deformable_group == 0, "oc %% deformable_group != 0");
megdnn_assert(
(offset[1] % (2 * fh * fw) == 0) && (mask[1] % (fh * fw) == 0),
"invalid deformable group deduced from offset(%s) or mask(%s)",
megdnn_layout_msg(offset).c_str(), megdnn_layout_msg(mask).c_str());
megdnn_assert(
(offset[1] / (2 * fh * fw)) == (mask[1] / (fh * fw)),
"offset(%s) and mask(%s) should have same deformable group",
megdnn_layout_msg(offset).c_str(), megdnn_layout_msg(mask).c_str());
megdnn_assert(
(offset[2] == mask[2]) && (offset[3] == mask[3]),
"offset(%s) and mask(%s) should have same spatial dim",
megdnn_layout_msg(offset).c_str(), megdnn_layout_msg(mask).c_str());
megdnn_assert(oh == offset[2], "deduced oh(%zu) != offset oh(%zu)", oh, offset[2]);
megdnn_assert(ow == offset[3], "deduced ow(%zu) != offset ow(%zu)", ow, offset[3]);
dst.ndim = 4;
dst = {{n, oc, oh, ow}, im.dtype};
}
void DeformableConvBase::check_layout_fwd(
const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& dst) {
auto& im_dtype = im.dtype;
TensorLayout dst_expected;
megdnn_assert(
im_dtype.enumv() == DTypeEnum::Float32,
"DeformableConv only support float32 input");
megdnn_assert_eq_dtype(im, dst);
megdnn_assert_eq_dtype(im, filter);
megdnn_assert_eq_dtype(im, dst);
megdnn_assert_eq_dtype(im, offset);
megdnn_assert_eq_dtype(im, mask);
deduce_layout_fwd(im, filter, offset, mask, dst_expected);
megdnn_assert_eq_layout(dst_expected, dst);
}
void DeformableConvForward::deduce_layout(
const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset,
const TensorLayout& mask, TensorLayout& dst) {
deduce_layout_fwd(im, filter, offset, mask, dst);
return;
}
CanonizedFilterMeta DeformableConvForward::check_exec(
const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& dst, size_t workspace_in_bytes) {
auto ret = make_canonized_filter_meta(im.ndim, filter, offset);
auto required_workspace_in_bytes =
get_workspace_in_bytes(im, filter, offset, mask, dst);
check_layout_fwd(im, filter, offset, mask, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
return ret;
}
CanonizedFilterMeta DeformableConvBackwardFilter::check_exec(
const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& filter_grad,
size_t workspace_in_bytes) {
check_layout_fwd(im, filter_grad, offset, mask, out_grad);
megdnn_assert_eq_dtype(im, filter_grad);
auto ret = make_canonized_filter_meta(im.ndim, filter_grad, offset);
auto required_workspace_in_bytes =
get_workspace_in_bytes(im, offset, mask, out_grad, filter_grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
return ret;
}
CanonizedFilterMeta DeformableConvBackwardData::check_exec(
const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& out_grad,
const TensorLayout& im_grad, const TensorLayout& offset_grad,
const TensorLayout& mask_grad, size_t workspace_in_bytes) {
check_layout_fwd(im, filter, offset, mask, out_grad);
megdnn_assert_eq_dtype(im, im_grad);
megdnn_assert_eq_dtype(im, offset_grad);
megdnn_assert_eq_dtype(im, mask_grad);
megdnn_assert_eq_shape(im, im_grad);
megdnn_assert_eq_shape(offset, offset_grad);
megdnn_assert_eq_shape(mask, mask_grad);
auto ret = make_canonized_filter_meta(im.ndim, filter, offset);
auto required_workspace_in_bytes = get_workspace_in_bytes(
im, filter, offset, mask, out_grad, im_grad, offset_grad, mask_grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
return ret;
}
}