megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/common/deformable_conv.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"

using namespace megdnn;

using CanonizedFilterMeta = DeformableConvBase::CanonizedFilterMeta;

namespace {

/*!
 * \brief Describe all operand layouts plus the convolution parameters as one
 *        comma-separated string, suitable for an assertion message.
 *
 * \param src, filter, offset, mask, dst  operand layouts, printed in this order
 * \param param  convolution parameter whose group / deformable_group / pad /
 *               stride / dilate fields are appended as "key=value" pairs
 * \return the assembled description
 */
template <typename Param>
std::string get_errmsg(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& offset,
        const TensorLayout& mask, const TensorLayout& dst, const Param& param) {
    // Kept from the original: presumably megdnn_layout_msg can discard its
    // argument in some build configurations. NOTE(review): confirm.
    MEGDNN_MARK_USED_VAR(src);
    MEGDNN_MARK_USED_VAR(filter);
    MEGDNN_MARK_USED_VAR(dst);

    std::string msg;
    msg += megdnn_layout_msg(src);
    msg += ", ";
    msg += megdnn_layout_msg(filter);
    msg += ", ";
    msg += megdnn_layout_msg(offset);
    msg += ", ";
    msg += megdnn_layout_msg(mask);
    msg += ", ";
    msg += megdnn_layout_msg(dst);
    msg += ", ";
    msg += "only support nchw";

    // Each parameter is emitted as ", <label><value>".
    auto append_field = [&msg](const char* label, const std::string& value) {
        msg += ", ";
        msg += label;
        msg += value;
    };
    append_field("group=", std::to_string(param.group));
    append_field("deformable_group=", std::to_string(param.deformable_group));
    append_field("pad_h=", std::to_string(param.pad_h));
    append_field("pad_w=", std::to_string(param.pad_w));
    append_field("stride_h=", std::to_string(param.stride_h));
    append_field("stride_w=", std::to_string(param.stride_w));
    append_field("dilate_h=", std::to_string(param.dilate_h));
    append_field("dilate_w=", std::to_string(param.dilate_w));
    return msg;
}

template <typename Param>
void make_canonized_filter_meta_nchw(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        CanonizedFilterMeta& ret) {
    megdnn_assert(
            param.mode == Param::Mode::CROSS_CORRELATION,
            "only support CROSS_CORRELATION mode");

    megdnn_assert(
            param.format == Param::Format::NCHW, "only support nchw input layout");

    size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;

    flt_start = 0, flt_spatial_start = 2;
    ocpg_pos = 0, icpg_pos = 1;

    if (param.sparse == Param::Sparse::GROUP)
        flt_start = 1;

    ret.spatial_ndim = src_ndim - 2;

    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and imput should be 4-dim; "
            "got input dim = %zu",
            src_ndim);

    ret.ocpg = filter[flt_start + ocpg_pos];
    ret.icpg = filter[flt_start + icpg_pos];

    auto dilation = ret.dilation;

    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu, %u", i,
                dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

}  // namespace

namespace megdnn {

/*!
 * \brief Build the canonized filter meta for a deformable convolution.
 *
 * Copies stride / padding / dilation from param(), reads group from the
 * filter when param().sparse == GROUP, delegates channel/spatial decoding to
 * make_canonized_filter_meta_nchw, and derives deformable_group from the
 * channel dim of \p offset.
 *
 * \param src_ndim  number of dims of the input image
 * \param filter    filter layout (must be contiguous)
 * \param offset    offset layout; offset[1] is expected to equal
 *                  2 * deformable_group * FH * FW
 */
CanonizedFilterMeta DeformableConvBase::make_canonized_filter_meta(
        size_t src_ndim, const TensorLayout& filter, const TensorLayout& offset) const {
    megdnn_assert_contiguous(filter);

    // Value-initialize so fields this function never assigns are zeroed
    // rather than left indeterminate in the returned copy.
    CanonizedFilterMeta ret{};
    ret.group = 1;
    ret.dtype = filter.dtype;
    ret.stride[0] = param().stride_h;
    ret.stride[1] = param().stride_w;
    ret.padding[0] = param().pad_h;
    ret.padding[1] = param().pad_w;
    ret.dilation[0] = param().dilate_h;
    ret.dilation[1] = param().dilate_w;

    if (param().sparse == Param::Sparse::GROUP) {
        megdnn_assert(filter.ndim == 5, "filter dim should be 5 for group conv");
        ret.group = filter[0];
    }

    make_canonized_filter_meta_nchw(src_ndim, filter, param(), ret);

    auto fh = ret.spatial[0];
    auto fw = ret.spatial[1];

    // Guard the division below: an empty kernel would divide by zero.
    megdnn_assert(
            fh > 0 && fw > 0, "filter spatial size should be > 0: %s",
            megdnn_layout_msg(filter).c_str());

    ret.deformable_group = offset[1] / (2 * fh * fw);

    return ret;
}

void DeformableConvBase::deduce_layout_fwd(
        const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset,
        const TensorLayout& mask, TensorLayout& dst) {
    // im shape: (n, IC, IH, IW)
    megdnn_assert(
            im.ndim == 4, "invalid src layout: %s", megdnn_layout_msg(im).c_str());
    // filter shape: (OC, IC, FH, FW) or (g, OC/g, IC/g, FH, FW)
    megdnn_assert(
            filter.ndim == 4 || filter.ndim == 5, "invalid filter layout: %s",
            megdnn_layout_msg(filter).c_str());
    // offset shape: (N, 2*dg*FH*FW, OH, OW)
    megdnn_assert(
            offset.ndim == 4, "invalid offset layout: %s",
            megdnn_layout_msg(offset).c_str());
    // mask shape: (N, dg*FH*FW, OH, OW)
    megdnn_assert(
            mask.ndim == 4, "invalid mask layout: %s", megdnn_layout_msg(mask).c_str());

    size_t n = im.shape[0], ic = im.shape[1];
    size_t ih = im.shape[2], iw = im.shape[3];
    size_t dh = param().dilate_h, dw = param().dilate_w;
    size_t ph = param().pad_h, pw = param().pad_w;
    size_t sh = param().stride_h, sw = param().stride_w;

    auto&& fm = make_canonized_filter_meta(im.ndim, filter, offset);
    size_t fh = fm.spatial[0], fw = fm.spatial[1];

    size_t kh = 1 + (fh - 1) * dh;
    size_t kw = 1 + (fw - 1) * dw;

    size_t group = fm.group;
    size_t deformable_group = fm.deformable_group;

    size_t icpg = fm.icpg, ocpg = fm.ocpg;
    size_t oc = group * ocpg;
    size_t oh = (ih + ph * 2 - kh) / sh + 1;
    size_t ow = (iw + pw * 2 - kw) / sw + 1;

    megdnn_assert(
            group > 0 && deformable_group > 0, "group and deformable group should > 0");
    megdnn_assert(ic == icpg * group, "im ic != group * icpg of filter");
    megdnn_assert(ic % deformable_group == 0, "ic %% deformable_group != 0");
    megdnn_assert(oc % deformable_group == 0, "oc %% deformable_group != 0");

    megdnn_assert(
            (offset[1] % (2 * fh * fw) == 0) && (mask[1] % (fh * fw) == 0),
            "invalid deformable group deduced from offset(%s) or mask(%s)",
            megdnn_layout_msg(offset).c_str(), megdnn_layout_msg(mask).c_str());

    megdnn_assert(
            (offset[1] / (2 * fh * fw)) == (mask[1] / (fh * fw)),
            "offset(%s) and mask(%s) should have same deformable group",
            megdnn_layout_msg(offset).c_str(), megdnn_layout_msg(mask).c_str());

    megdnn_assert(
            (offset[2] == mask[2]) && (offset[3] == mask[3]),
            "offset(%s) and mask(%s) should have same spatial dim",
            megdnn_layout_msg(offset).c_str(), megdnn_layout_msg(mask).c_str());
    megdnn_assert(oh == offset[2], "deduced oh(%zu) != offset oh(%zu)", oh, offset[2]);
    megdnn_assert(ow == offset[3], "deduced ow(%zu) != offset ow(%zu)", ow, offset[3]);
    dst.ndim = 4;

    dst = {{n, oc, oh, ow}, im.dtype};
}
/*!
 * \brief Check that the forward operands are float32, share im's dtype, and
 *        that \p dst equals the layout deduced by deduce_layout_fwd.
 */
void DeformableConvBase::check_layout_fwd(
        const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset,
        const TensorLayout& mask, const TensorLayout& dst) {
    auto& im_dtype = im.dtype;
    megdnn_assert(
            im_dtype.enumv() == DTypeEnum::Float32,
            "DeformableConv only support float32 input");
    // Every operand must share the input dtype. The dst check used to be
    // duplicated here; one is enough.
    megdnn_assert_eq_dtype(im, dst);
    megdnn_assert_eq_dtype(im, filter);
    megdnn_assert_eq_dtype(im, offset);
    megdnn_assert_eq_dtype(im, mask);

    TensorLayout dst_expected;
    deduce_layout_fwd(im, filter, offset, mask, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
}

/*!
 * \brief Public forward shape inference; the work is done by
 *        DeformableConvBase::deduce_layout_fwd.
 */
void DeformableConvForward::deduce_layout(
        const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset,
        const TensorLayout& mask, TensorLayout& dst) {
    deduce_layout_fwd(im, filter, offset, mask, dst);
}

/*!
 * \brief Validate forward operands and workspace size before execution.
 *
 * \return the canonized filter meta for \p filter
 */
CanonizedFilterMeta DeformableConvForward::check_exec(
        const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset,
        const TensorLayout& mask, const TensorLayout& dst, size_t workspace_in_bytes) {
    // Validate layouts first, as the backward check_exec variants do. The
    // filter-meta construction (which reads offset[1]) and the workspace
    // query then only ever see well-formed layouts.
    check_layout_fwd(im, filter, offset, mask, dst);
    auto ret = make_canonized_filter_meta(im.ndim, filter, offset);
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(im, filter, offset, mask, dst);
    megdnn_assert(
            workspace_in_bytes >= required_workspace_in_bytes,
            "workspace too small: required %zu bytes, got %zu",
            static_cast<size_t>(required_workspace_in_bytes), workspace_in_bytes);
    return ret;
}

/*!
 * \brief Validate operands and workspace size for the filter-gradient pass.
 *
 * filter_grad is validated exactly as the forward filter, and out_grad as
 * the forward dst.
 *
 * \return the canonized filter meta derived from \p filter_grad
 */
CanonizedFilterMeta DeformableConvBackwardFilter::check_exec(
        const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask,
        const TensorLayout& out_grad, const TensorLayout& filter_grad,
        size_t workspace_in_bytes) {
    // Forward-style validation with (filter_grad, out_grad) standing in for
    // (filter, dst), followed by an explicit dtype check on filter_grad.
    check_layout_fwd(im, filter_grad, offset, mask, out_grad);
    megdnn_assert_eq_dtype(im, filter_grad);

    auto filter_meta = make_canonized_filter_meta(im.ndim, filter_grad, offset);
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(im, offset, mask, out_grad, filter_grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);

    return filter_meta;
}

/*!
 * \brief Validate operands and workspace size for the data-gradient pass.
 *
 * The forward inputs are checked with out_grad in the role of the forward
 * dst. Each produced gradient (im_grad, offset_grad, mask_grad) must share
 * im's dtype and the shape of the tensor it corresponds to.
 *
 * \return the canonized filter meta derived from \p filter
 */
CanonizedFilterMeta DeformableConvBackwardData::check_exec(
        const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset,
        const TensorLayout& mask, const TensorLayout& out_grad,
        const TensorLayout& im_grad, const TensorLayout& offset_grad,
        const TensorLayout& mask_grad, size_t workspace_in_bytes) {
    check_layout_fwd(im, filter, offset, mask, out_grad);

    // Gradient dtypes follow the input dtype...
    megdnn_assert_eq_dtype(im, im_grad);
    megdnn_assert_eq_dtype(im, offset_grad);
    megdnn_assert_eq_dtype(im, mask_grad);

    // ...and each gradient mirrors the shape of its forward counterpart.
    megdnn_assert_eq_shape(im, im_grad);
    megdnn_assert_eq_shape(offset, offset_grad);
    megdnn_assert_eq_shape(mask, mask_grad);

    auto filter_meta = make_canonized_filter_meta(im.ndim, filter, offset);
    auto required_workspace_in_bytes = get_workspace_in_bytes(
            im, filter, offset, mask, out_grad, im_grad, offset_grad, mask_grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);

    return filter_meta;
}
}  // namespace megdnn

// vim: syntax=cpp.doxygen