// megenginelite-sys 1.8.2
//
// A safe megenginelite wrapper in Rust
// Documentation
/**
 * \file dnn/src/common/local_share/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

/**
 * Deduce the output layout of a forward local-share operation.
 *
 * Operand layouts, as implied by the indexing below:
 *  - src:    4-d {n, ci, hi, wi}, contiguous
 *  - filter: contiguous; 6-d {sgh, sgw, ci, fh, fw, co} when sparse == DENSE,
 *            7-d {groups, sgh, sgw, ci / groups, fh, fw, co / groups} when
 *            sparse == GROUP
 * On success dst is assigned TensorLayout{{n, co, ho, wo}, src.dtype}, where
 * ho/wo come from infer_conv_shape and must be divisible by the spatial groups.
 *
 * Only NCHW format, float32 with ComputeMode::DEFAULT and dilation 1 are
 * accepted; every violation fails a megdnn_assert.
 */
void LocalShareBase::deduce_layout_fwd(
        const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
    using Mode = LocalShare::Param::Mode;
    // Dump of every operand and parameter, appended to the shape assertions below.
    auto errmsg =
            megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
            megdnn_layout_msg(dst) + ", " +
            "is_xcorr=" + std::to_string((param().mode == Mode::CROSS_CORRELATION)) +
            ", " + "pad_h=" + std::to_string(param().pad_h) + ", " +
            "pad_w=" + std::to_string(param().pad_w) + ", " +
            "stride_h=" + std::to_string(param().stride_h) + ", " +
            "stride_w=" + std::to_string(param().stride_w) + ", " +
            "dilate_h=" + std::to_string(param().dilate_h) + ", " +
            "dilate_w=" + std::to_string(param().dilate_w) + ", " +
            "spatial_groups_h=" + std::to_string(param().spatial_groups_h) + ", " +
            "spatial_groups_w=" + std::to_string(param().spatial_groups_w);
    auto errmsg_c = errmsg.c_str();
    // Keeps errmsg_c "used" even if megdnn_assert discards its message
    // arguments (presumably in some build configurations).
    MEGDNN_MARK_USED_VAR(errmsg_c);

    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(filter);
    using Param = LocalShare::Param;
    using Sparse = Param::Sparse;
    using Format = Param::Format;
    using ComputeMode = Param::ComputeMode;
    megdnn_assert(
            param().format == Format::NCHW, "local shared only support NCHW format");
    megdnn_assert(src.ndim == 4_z, "%s", errmsg_c);
    megdnn_assert(
            (filter.ndim == 6_z && param().sparse == Sparse::DENSE) ||
                    (filter.ndim == 7_z && param().sparse == Sparse::GROUP),
            "%s", errmsg_c);
    megdnn_assert(
            param().dilate_h == 1 && param().dilate_w == 1,
            "dilated local shared is not supported");
    megdnn_assert(
            src.dtype == dtype::Float32() &&
                    param().computeMode == ComputeMode::DEFAULT,
            "local shared only support float32");

    size_t n = src[0], ci = src[1], hi = src[2], wi = src[3];
    size_t sgh = param().spatial_groups_h, sgw = param().spatial_groups_w;
    // The divisibility check on the output size below uses sgh/sgw as
    // divisors; zero would be a division by zero (undefined behavior).
    megdnn_assert(
            sgh > 0 && sgw > 0, "spatial groups must be positive %s", errmsg_c);

    // In GROUP mode the leading filter dimension is the group count, so every
    // per-group dimension is shifted right by one.
    size_t groups = 1;
    size_t weights_shp_pos = 0;
    if (param().sparse == Sparse::GROUP) {
        groups = filter[0];
        weights_shp_pos = 1;
    }
    megdnn_assert(
            sgh == filter[weights_shp_pos] && sgw == filter[weights_shp_pos + 1],
            "spatial groups in filter tensor mismatch with those "
            "provided in parameter %s",
            errmsg_c);
    size_t fh = filter[weights_shp_pos + 3], fw = filter[weights_shp_pos + 4],
           co = filter[weights_shp_pos + 5] * groups;
    megdnn_assert(
            filter[weights_shp_pos + 2] * groups == ci,
            "input channels of src and filter mismatch %s", errmsg_c);
    size_t sh = param().stride_h;
    size_t sw = param().stride_w;
    size_t ph = param().pad_h;
    size_t pw = param().pad_w;
    size_t ho = infer_conv_shape(hi, fh, sh, ph), wo = infer_conv_shape(wi, fw, sw, pw);
    megdnn_assert(
            ho % sgh == 0 && wo % sgw == 0,
            "height and width of output cannot be divided by spatial groups %s",
            errmsg_c);
    dst = TensorLayout{{n, co, ho, wo}, src.dtype};
}

void LocalShareBase::check_layout_fwd(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
    TensorLayout dst_expected;
    megdnn_assert_eq_dtype(src, filter);
    megdnn_assert_eq_dtype(src, dst);
    deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);

    megdnn_assert(src.dtype == dtype::Float32());
}

void LocalShareForward::deduce_layout(
        const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
    deduce_layout_fwd(src, filter, dst);
}

void LocalShareForward::check_exec(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        size_t workspace_in_bytes) {
    check_layout_fwd(src, filter, dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

/**
 * Deduce the layout of the input gradient (grad) for the backward-data pass
 * from the filter and the output gradient (diff).
 *
 * Operand layouts, as implied by the indexing below:
 *  - diff:   4-d {n, co, ho, wo}, contiguous; ho/wo must be divisible by the
 *            spatial groups
 *  - filter: contiguous; 6-d {sgh, sgw, ci, fh, fw, co} when sparse == DENSE,
 *            7-d {groups, sgh, sgw, ci / groups, fh, fw, co / groups} when
 *            sparse == GROUP
 * grad becomes a contiguous 4-d {n, ci, hi, wi} layout with diff's dtype, where
 * each spatial size inverts the convolution output-size relation:
 *   in = (out - 1) * stride + filter_size - 2 * pad
 *
 * Only NCHW format, float32 with ComputeMode::DEFAULT and dilation 1 are
 * accepted; every violation fails a megdnn_assert.
 */
void LocalShareBackwardData::deduce_layout(
        const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) {
    using Mode = LocalShare::Param::Mode;
    // Dump of every operand and parameter, appended to the shape assertions below.
    auto errmsg =
            megdnn_layout_msg(filter) + ", " + megdnn_layout_msg(diff) + ", " +
            megdnn_layout_msg(grad) + ", " +
            "is_xcorr=" + std::to_string((param().mode == Mode::CROSS_CORRELATION)) +
            ", " + "pad_h=" + std::to_string(param().pad_h) + ", " +
            "pad_w=" + std::to_string(param().pad_w) + ", " +
            "stride_h=" + std::to_string(param().stride_h) + ", " +
            "stride_w=" + std::to_string(param().stride_w) + ", " +
            "dilate_h=" + std::to_string(param().dilate_h) + ", " +
            "dilate_w=" + std::to_string(param().dilate_w) + ", " +
            "spatial_groups_h=" + std::to_string(param().spatial_groups_h) + ", " +
            "spatial_groups_w=" + std::to_string(param().spatial_groups_w);
    auto errmsg_c = errmsg.c_str();
    // Keeps errmsg_c "used" even if megdnn_assert discards its message
    // arguments (presumably in some build configurations).
    MEGDNN_MARK_USED_VAR(errmsg_c);

    megdnn_assert_contiguous(filter);
    megdnn_assert_contiguous(diff);
    using Param = LocalShare::Param;
    using Sparse = Param::Sparse;
    using Format = Param::Format;
    using ComputeMode = Param::ComputeMode;
    megdnn_assert(
            param().format == Format::NCHW, "local shared only support NCHW format");
    megdnn_assert(
            (filter.ndim == 6_z && param().sparse == Sparse::DENSE) ||
                    (filter.ndim == 7_z && param().sparse == Sparse::GROUP),
            "%s", errmsg_c);
    megdnn_assert(diff.ndim == 4_z, "%s", errmsg_c);
    megdnn_assert(
            param().dilate_h == 1 && param().dilate_w == 1,
            "dilated local shared is not supported");
    megdnn_assert(
            diff.dtype == dtype::Float32() &&
                    param().computeMode == ComputeMode::DEFAULT,
            "local shared only support float32");

    size_t n = diff[0], co = diff[1], ho = diff[2], wo = diff[3];
    size_t sgh = param().spatial_groups_h, sgw = param().spatial_groups_w;
    // sgh/sgw are used as divisors immediately below; reject zero before the
    // modulo to avoid a division by zero (undefined behavior).
    megdnn_assert(
            sgh > 0 && sgw > 0, "spatial groups must be positive %s", errmsg_c);
    megdnn_assert(
            ho % sgh == 0 && wo % sgw == 0,
            "height and width of output cannot be divided by spatial groups %s",
            errmsg_c);

    // In GROUP mode the leading filter dimension is the group count, so every
    // per-group dimension is shifted right by one.
    size_t groups = 1;
    size_t weights_shp_pos = 0;
    if (param().sparse == Sparse::GROUP) {
        groups = filter[0];
        weights_shp_pos = 1;
    }
    megdnn_assert(
            sgh == filter[weights_shp_pos] && sgw == filter[weights_shp_pos + 1],
            "spatial groups in filter tensor mismatch with those "
            "provided in parameter %s",
            errmsg_c);
    size_t ci = filter[weights_shp_pos + 2] * groups, fh = filter[weights_shp_pos + 3],
           fw = filter[weights_shp_pos + 4];
    // The last filter dimension is the per-group *output* channel count, which
    // must match the channel count of diff (the output gradient).
    megdnn_assert(
            filter[weights_shp_pos + 5] * groups == co,
            "output channels of diff and filter mismatch %s", errmsg_c);
    size_t sh = param().stride_h;
    size_t sw = param().stride_w;
    size_t ph = param().pad_h;
    size_t pw = param().pad_w;

    // Inverse of the convolution output-size formula for one spatial axis.
    auto deduce = [&errmsg_c](size_t out, size_t fsize, size_t stride, size_t pad) {
        MEGDNN_MARK_USED_VAR(errmsg_c);
        // out == 0 would make (out - 1) wrap around to SIZE_MAX on size_t.
        megdnn_assert(out > 0, "%s", errmsg_c);
        auto i = (out - 1) * stride + fsize;
        // The deduced input size must stay positive after removing padding.
        megdnn_assert(i > pad * 2, "%s", errmsg_c);
        return i - pad * 2;
    };
    grad.ndim = diff.ndim;
    grad[0] = n;
    grad[1] = ci;
    grad[2] = deduce(ho, fh, sh, ph);
    grad[3] = deduce(wo, fw, sw, pw);
    grad.init_contiguous_stride();
    grad.dtype = diff.dtype;
}

void LocalShareBackwardData::check_exec(
        const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad,
        size_t workspace_in_bytes) {
    auto filter_dtype = filter.dtype, diff_dtype = diff.dtype, grad_dtype = grad.dtype;
    megdnn_assert(
            filter_dtype == dtype::Float32() && filter_dtype == diff_dtype &&
            filter_dtype == grad_dtype);
    check_layout_fwd(grad, filter, diff);
    auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

void LocalShareBackwardFilter::check_exec(
        const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
        size_t workspace_in_bytes) {
    auto src_dtype = src.dtype, diff_dtype = diff.dtype, grad_dtype = grad.dtype;
    megdnn_assert(
            src_dtype == dtype::Float32() && src_dtype == diff_dtype &&
            src_dtype == grad_dtype);
    check_layout_fwd(src, grad, diff);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen