megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file src/x86/conv_bias/int8/avx2_chanwise_stride1.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/x86/conv_bias/int8/avx2_chanwise_stride1.h"
#include "src/x86/conv_bias/int8/avx2_chanwise_kern.h"
#include "src/x86/elemwise_op.h"

namespace megdnn {
namespace x86 {
namespace avx2_chanwise_stride1 {

template <size_t filter, BiasMode bias_mode, bool is_quantized, typename Op>
void conv_kimpl(
        const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
        const NCBKernIndex& ncb_index) {
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
    bool need_src_copy_var = need_src_copy(kern_param);
    bool need_dst_copy_var = need_dst_copy(kern_param);
    bool need_post_process = kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;

    Op op = Op(1.0f, 4.0f);
    if (need_post_process) {
        float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
        float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
        op = Op(scale_bias, scale_dst);
    }
    size_t padding_group_size = IH2 * IW2;
    size_t workspace_group_id = ncb_index.thread_id;
    size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1];

    const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id);
    const int8_t* fptr = kern_param.filter<dt_int8>(group_id);
    void* dst = kern_param.dst<void>(batch_id, group_id);
    const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id);
    if (need_src_copy_var) {
        sptr = static_cast<int8_t*>(bundle.get(0)) +
               workspace_group_id * padding_group_size;
    }
    void* dptr = nullptr;
    int32_t* tptr = nullptr;
    if (need_dst_copy_var) {
        dptr = reinterpret_cast<void*>(
                reinterpret_cast<ptrdiff_t>(bundle.get(1)) +
                ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size());
    } else {
        dptr = dst;
    }

#define KERN_NEED_POST_PROCESS(filter)                                              \
    avx2_chanwise_direct_stride1_##filter##x##filter##_int8<bias_mode, true, Op>(   \
            sptr, fptr, bptr, tptr, static_cast<int8_t*>(dptr), IH2, IW2, OH2, OW2, \
            op)

#define KERN_NO_POST_PROCESS(filter)                                               \
    avx2_chanwise_direct_stride1_##filter##x##filter##_int8<bias_mode, false, Op>( \
            sptr, fptr, bptr, static_cast<int32_t*>(dptr), nullptr, IH2, IW2, OH2, \
            OW2, op)

    if (need_post_process) {
        tptr = static_cast<int32_t*>(bundle.get(2)) +
               ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size();
        DISPATCH_FILTER(filter, KERN_NEED_POST_PROCESS)
    } else {
        DISPATCH_FILTER(filter, KERN_NO_POST_PROCESS)
    }

#undef KERN_NEED_POST_PROCESS
#undef KERN_NO_POST_PROCESS
    if (need_dst_copy_var) {
        rep(oh, OH) {
            std::memcpy(
                    reinterpret_cast<void*>(
                            reinterpret_cast<ptrdiff_t>(dst) +
                            oh * OW * kern_param.dst_type.size()),
                    reinterpret_cast<void*>(
                            reinterpret_cast<ptrdiff_t>(dptr) +
                            oh * OW2 * kern_param.dst_type.size()),
                    kern_param.dst_type.size() * OW);
        }
    }
};
/**
 * Build the kernel list for the int8 channel-wise stride-1 convolution.
 *
 * Selects a conv_kimpl instantiation from:
 *  - the filter size (filter_meta.spatial[0] in {2, 3, 5, 7});
 *  - the dst dtype (QuantizedS8 vs. QuantizedS32/Int32);
 *  - the bias mode (NO_BIAS or BROADCAST_CHANNEL_BIAS);
 *  - the nonlinearity (IDENTITY, RELU or H_SWISH).
 * Any other value of these parameters triggers megdnn_assert.
 *
 * @param kern_param size/shape parameters of the convolution.
 * @param bundle     workspace bundle. It is copied into the returned kernel,
 *                   which rebinds it to kern_param.workspace_ptr on each call.
 * @return a single NCBKern over the (group, n, 1) index space. It runs
 *         copy_padding_kern and then the selected conv_kimpl.
 */
SmallVector<NCBKern> get_kimpls(
        const NCBKernSizeParam& kern_param, const WorkspaceBundle& bundle) {
    // Bind by reference: only `group` is read, no need to copy the meta.
    const auto& fm = kern_param.filter_meta;
    size_t group = fm.group;
    size_t n = kern_param.n;

    SmallVector<NCBKern> ncb_kerns;
    conv_fun do_conv_fun = nullptr;

#define DO_CONV_KERN_FUN(filter, bias_mode, is_quantized, op) \
    do_conv_fun = conv_kimpl<filter, bias_mode, is_quantized, op>;

#define GET_OP_PARAM(i, bias_mode, is_quantized)                                 \
    switch (kern_param.nonlineMode) {                                            \
        case param::ConvBias::NonlineMode::IDENTITY:                             \
            DO_CONV_KERN_FUN(                                                    \
                    i, bias_mode, is_quantized,                                  \
                    TypeCvtOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32 MEGDNN_COMMA \
                                      dt_qint8>)                                 \
            break;                                                               \
        case param::ConvBias::NonlineMode::RELU:                                 \
            DO_CONV_KERN_FUN(                                                    \
                    i, bias_mode, is_quantized,                                  \
                    ReluOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32 MEGDNN_COMMA    \
                                   dt_qint8>)                                    \
            break;                                                               \
        case param::ConvBias::NonlineMode::H_SWISH:                              \
            DO_CONV_KERN_FUN(                                                    \
                    i, bias_mode, is_quantized,                                  \
                    HSwishOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32 MEGDNN_COMMA  \
                                     dt_qint8>)                                  \
            break;                                                               \
        default:                                                                 \
            megdnn_assert(0);                                                    \
            break;                                                               \
    }

#define GET_BIAS_MODE_PARAM(i, is_quantized)                                \
    switch (kern_param.bias_mode) {                                         \
        case BiasMode::NO_BIAS:                                             \
            GET_OP_PARAM(i, BiasMode::NO_BIAS, is_quantized)                \
            break;                                                          \
        case BiasMode::BROADCAST_CHANNEL_BIAS:                              \
            GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS, is_quantized) \
            break;                                                          \
        default:                                                            \
            megdnn_assert(0);                                               \
            break;                                                          \
    }

// QuantizedS32 and Int32 outputs share the non-requantizing path.
#define GET_QUANTIZED(i)                   \
    switch (kern_param.dst_type.enumv()) { \
        case DTypeEnum::QuantizedS8:       \
            GET_BIAS_MODE_PARAM(i, true)   \
            break;                         \
        case DTypeEnum::QuantizedS32:      \
        case DTypeEnum::Int32:             \
            GET_BIAS_MODE_PARAM(i, false)  \
            break;                         \
        default:                           \
            megdnn_assert(0);              \
            break;                         \
    }

#define DISPATCH_CONV_KERN()                     \
    switch (kern_param.filter_meta.spatial[0]) { \
        case 2:                                  \
            GET_QUANTIZED(2)                     \
            break;                               \
        case 3:                                  \
            GET_QUANTIZED(3)                     \
            break;                               \
        case 5:                                  \
            GET_QUANTIZED(5)                     \
            break;                               \
        case 7:                                  \
            GET_QUANTIZED(7)                     \
            break;                               \
        default:                                 \
            megdnn_assert(0);                    \
            break;                               \
    }

    DISPATCH_CONV_KERN();

    // Keep these dispatch helpers local to this function, like the KERN_*
    // macros in conv_kimpl, so they cannot collide with later code in the TU.
#undef DISPATCH_CONV_KERN
#undef GET_QUANTIZED
#undef GET_BIAS_MODE_PARAM
#undef GET_OP_PARAM
#undef DO_CONV_KERN_FUN

    // The bundle is captured by value (mutable) because each invocation
    // rebinds it to the workspace pointer supplied at run time.
    auto exec_one_group = [bundle = bundle, do_conv_fun](
                                  const NCBKernParam& kern_param,
                                  const NCBKernIndex& ncb_index) mutable {
        bundle.set(kern_param.workspace_ptr);
        copy_padding_kern(bundle, kern_param, ncb_index);
        do_conv_fun(bundle, kern_param, ncb_index);
    };
    ncb_kerns.push_back({exec_one_group, {group, n, 1_z}});

    return ncb_kerns;
}

}  // namespace avx2_chanwise_stride1
}  // namespace x86
}  // namespace megdnn

// vim: syntax=cpp.doxygen