megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/fallback/resize/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/fallback/resize/opr_impl.h"
#include <vector>
#include "src/common/rounding_converter.cuh"
#include "src/fallback/handle.h"

using namespace megdnn;
using namespace fallback;

template <typename ctype>
void ResizeImpl::kern_fallback(const KernParam<ctype>& kern_param) {
    if (kern_param.format == Format::NHWC) {
        kern_fallback_nhwc(kern_param);
        return;
    }
    megdnn_assert(kern_param.format == Format::NCHW);

    UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param);
    rounding::RoundingConverter<ctype> output_converter;
    float scale_h = static_cast<float>(OH) / IH;
    float scale_w = static_cast<float>(OW) / IW;

    auto build_table = [this](InterpolationMode imode, float scale, int isize,
                              int osize) {
        std::vector<std::tuple<float, int, float, int>> table;
        rep(i, osize) {
            table.push_back(get_nearest_linear_coord(imode, scale, isize, i));
        }
        return table;
    };

    auto table_h = build_table(kern_param.imode, scale_h, IH, OH);
    auto table_w = build_table(kern_param.imode, scale_w, IW, OW);

    rep(n, N) {
        rep(c, static_cast<int>(C)) {
            rep(oh, OH) {
                float ah0, ah1, aw0, aw1;
                int ih0, ih1, iw0, iw1;

                std::tie(ah0, ih0, ah1, ih1) = table_h[oh];
                rep(ow, OW) {
                    std::tie(aw0, iw0, aw1, iw1) = table_w[ow];
                    dptr[c * OH * OW + oh * OW + ow] = output_converter(
                            sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * ah0 * aw0 +
                            sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * ah0 * aw1 +
                            sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * ah1 * aw0 +
                            sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * ah1 * aw1);
                }
            }
        }
        sptr += S_IN;
        dptr += C * OH * OW;
    }
}

template <typename ctype>
void ResizeImpl::kern_fallback_nhwc(const KernParam<ctype>& kern_param) {
    UNPACK_RESIZE_FWD_KERN_PARAM(kern_param);
    rounding::RoundingConverter<ctype> output_converter;
    float scale_h = static_cast<float>(OH) / IH;
    float scale_w = static_cast<float>(OW) / IW;

    auto build_table = [this](InterpolationMode imode, float scale, int isize,
                              int osize) {
        std::vector<std::tuple<float, int, float, int>> table;
        rep(i, osize) {
            table.push_back(get_nearest_linear_coord(imode, scale, isize, i));
        }
        return table;
    };
    auto table_h = build_table(kern_param.imode, scale_h, IH, OH);
    auto table_w = build_table(kern_param.imode, scale_w, IW, OW);

    rep(n, N) {
        rep(oh, OH) {
            float ah0, ah1, aw0, aw1;
            int ih0, ih1, iw0, iw1;

            std::tie(ah0, ih0, ah1, ih1) = table_h[oh];
            rep(ow, OW) {
                std::tie(aw0, iw0, aw1, iw1) = table_w[ow];
                rep(c, C) {
                    dptr[(oh * OW + ow) * C + c] = output_converter(
                            sptr[(ih0 * IW + iw0) * C + c] * ah0 * aw0 +
                            sptr[(ih0 * IW + iw1) * C + c] * ah0 * aw1 +
                            sptr[(ih1 * IW + iw0) * C + c] * ah1 * aw0 +
                            sptr[(ih1 * IW + iw1) * C + c] * ah1 * aw1);
                }
            }
        }
        sptr += C * IH * IW;
        dptr += C * OH * OW;
    }
}

void ResizeImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
    if (param().format == param::Resize::Format::NCHW4 ||
        param().format == param::Resize::Format::NCHW44 ||
        param().format == param::Resize::Format::NCHW88 ||
        (param().format == param::Resize::Format::NCHW &&
         param().imode != param::Resize::InterpolationMode::INTER_LINEAR)) {
        naive::ResizeImpl::exec(src, dst, workspace);
        return;
    }
    if ((param().format == param::Resize::Format::NCHW ||
         (src.layout[3] != 1 && src.layout[3] != 3)) ||
        (param().imode == param::Resize::InterpolationMode::LINEAR)) {
#define cb(dt, ct)                                                   \
    case DTypeTrait<dt>::enumv: {                                    \
        auto kparam = KernParam<ct>::from_tensors(                   \
                param().format, param().imode, src, dst, workspace); \
        MEGDNN_DISPATCH_CPU_KERN_OPR(kern_fallback(kparam));         \
        return;                                                      \
    }

        switch (src.layout.dtype.enumv()) {
            cb(dtype::Float32, float);
            DNN_INC_FLOAT16(cb(dtype::Float16, dt_float16));
            cb(dtype::Int8, int8_t);
            cb(dtype::QuantizedS8, int8_t);
            cb(dtype::Uint8, uint8_t);
            cb(dtype::Quantized8Asymm, uint8_t);
            default:
                megdnn_throw(ssprintf(
                                     "Unsupported input DType in Resize: %s",
                                     src.layout.dtype.name())
                                     .c_str());
                return;
        }

#undef cb
    }

    naive::ResizeImpl::exec(src, dst, workspace);
}

// vim: syntax=cpp.doxygen