megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are
 *permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this
 *list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this
 *list of conditions and the following disclaimer in the documentation and/or other
 *materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors
 *may be used to endorse or promote products derived from this software without specific
 *prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
 *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
 *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/**
 * \file dnn/src/cuda/convolution_helper/conv_trait/iconv_trait.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "src/cuda/convolution_helper/block_tile_consumer/block_consumer.cuh"
#include "src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator.cuh"
#include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor.cuh"
#include "src/cuda/convolution_helper/global_memory_writer/global_memory_writer.cuh"
#include "src/cuda/convolution_helper/layout.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"

namespace megdnn {
namespace cuda {
namespace convolution {
#define COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(                                 \
        _src_dtype, _filter_dtype, _smem_storage_dtype, _input_layout, _kern_layout, \
        _output_layout, _conv_param)                                                 \
    using src_dtype = _src_dtype;                                                    \
    using filter_dtype = _filter_dtype;                                              \
    using smem_storage_dtype = _smem_storage_dtype;                                  \
    using InputLayout = _input_layout;                                               \
    using KernLayout = _kern_layout;                                                 \
    using OutputLayout = _output_layout;                                             \
    using Param = _conv_param;                                                       \
    static constexpr bool check_bounds = check_bounds_;
#define MEGDNN_COMMA ,

template <
        bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_,
        typename ThreadConfig_>
struct IConvTrait {
    COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(
            int8_t, int8_t, int32_t, Layout<Format::CHWN4>, Layout<Format::CHWN4>,
            Layout<Format::CHWN4>, ConvParam);
    using RegBlockConfig = RegBlockConfig_;
    using ThreadConfig = ThreadConfig_;
    struct DataTileCount {
        using RegBlockConfig = RegBlockConfig;
        using ThreadConfig = ThreadConfig;
        using copy_t = ldg_dtype;
        using smem_storage_dtype = smem_storage_dtype;
        static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype);
        static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(src_dtype);
        static int constexpr skew = load_width;
        static int constexpr block_tile_batch =
                RegBlockConfig::reg_n * ThreadConfig::nr_thread_x;
        static int constexpr block_tile_in_channel = RegBlockConfig::reg_k;

        static int constexpr smem_load_x = block_tile_batch / load_width;
        static int constexpr load_x = smem_load_x > 32 ? 32 : smem_load_x;
        static int constexpr load_y = ThreadConfig::nr_threads / load_x;

        static int constexpr smem_h = RegBlockConfig::reg_k_packed;
        static int constexpr smem_w = block_tile_batch;
        static int constexpr smem_stride = smem_w % 2 == 0 ? smem_w + skew : smem_w;
        static int constexpr smem_tot = smem_h * smem_stride;

        static int constexpr reg_h = (smem_h + load_y - 1) / load_y;
        static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x;

        static bool constexpr check_bounds_h = smem_h % load_y != 0;
        static bool constexpr check_bounds_w = smem_load_x % load_x != 0;
    };

    struct FilterTileCount {
        using RegBlockConfig = RegBlockConfig;
        using ThreadConfig = ThreadConfig;
        using copy_t = ldg_dtype;
        using smem_storage_dtype = smem_storage_dtype;
        static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype);
        static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(filter_dtype);
        static int constexpr skew = load_width;
        static int constexpr block_tile_out_channel =
                RegBlockConfig::reg_m * ThreadConfig::nr_thread_y;
        static int constexpr block_tile_in_channel = RegBlockConfig::reg_k;

        static int constexpr smem_load_x = block_tile_out_channel / load_width;
        static int constexpr load_x = smem_load_x > 32 ? 32 : smem_load_x;
        static int constexpr load_y = ThreadConfig::nr_threads / load_x;

        static int constexpr smem_h = RegBlockConfig::reg_k_packed;
        static int constexpr smem_w = block_tile_out_channel;
        static int constexpr smem_stride = smem_w % 2 == 0 ? smem_w + skew : smem_w;
        static int constexpr smem_tot = smem_h * smem_stride;

        static int constexpr reg_h = (smem_h + load_y - 1) / load_y;
        static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x;

        static bool constexpr check_bounds_h = smem_h % load_y != 0;
        static bool constexpr check_bounds_w = smem_load_x % load_x != 0;
    };

    using BlockTileIterator = BlockTileIteratorBasic<DataTileCount, FilterTileCount>;
    using DataGlobal2ShareMemVisitor =
            Global2ShareMemVisitor_CIxN<check_bounds, DataTileCount, InputLayout>;
    using FilterGlobal2ShareMemVisitor =
            Global2ShareMemVisitor_CIxN<check_bounds, FilterTileCount, KernLayout>;
    static bool constexpr pipelined = RegBlockConfig::reg_k_packed > 1;
    using BlockConsumer = IConvBlockConsumer<RegBlockConfig, ThreadConfig, pipelined>;
    using GlobalMemoryWriter = IConvGlobalMemoryWriter<RegBlockConfig, ThreadConfig>;
};

template <
        bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_,
        typename ThreadConfig_>
struct IConvTraitUnrollWidth {
    COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(
            int8_t, int8_t, int32_t, Layout<Format::CHWN4>, Layout<Format::CHWN4>,
            Layout<Format::CHWN4>, ConvParam);
    using RegBlockConfig = RegBlockConfig_;
    using ThreadConfig = ThreadConfig_;
    struct DataTileCount {
        using RegBlockConfig = RegBlockConfig;
        using ThreadConfig = ThreadConfig;
        using copy_t = ldg_dtype;
        using smem_storage_dtype = smem_storage_dtype;
        static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype);
        static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(src_dtype);
        static int constexpr skew = load_width;
        static int constexpr block_tile_batch =
                RegBlockConfig::reg_n * ThreadConfig::nr_thread_x;
        static int constexpr block_tile_out_width = RegBlockConfig::reg_width;
        static int constexpr block_tile_in_channel = RegBlockConfig::reg_k;

        static int constexpr smem_load_x = block_tile_batch / load_width;
        static int constexpr load_x = smem_load_x > 32 ? 32 : smem_load_x;
        static int constexpr load_y = ThreadConfig::nr_threads / load_x;

        static int constexpr smem_h = RegBlockConfig::reg_k_packed;
        static int constexpr smem_w = block_tile_batch;
        static int constexpr img_cache = RegBlockConfig::reg_width;
        static int constexpr smem_stride = smem_w % 2 == 0 ? smem_w + skew : smem_w;
        static int constexpr smem_tot = smem_h * img_cache * smem_stride;

        static int constexpr reg_h = (smem_h + load_y - 1) / load_y;
        static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x;

        static bool constexpr check_bounds_h = smem_h % load_y != 0;
        static bool constexpr check_bounds_w = smem_load_x % load_x != 0;
    };
    MEGDNN_STATIC_ASSERT(
            std::is_same<typename IConvTrait<
                            check_bounds MEGDNN_COMMA ldg_dtype MEGDNN_COMMA
                                    RegBlockConfig MEGDNN_COMMA ThreadConfig>::
                                 filter_dtype MEGDNN_COMMA filter_dtype>::value == true,
            "data type of filter tensor should be int8_t");
    using FilterTileCount = typename IConvTrait<
            check_bounds, ldg_dtype, RegBlockConfig, ThreadConfig>::FilterTileCount;
    using BlockTileIterator =
            BlockTileIteratorUnrollWidth<DataTileCount, FilterTileCount>;
    using DataGlobal2ShareMemVisitor =
            Global2ShareMemVisitor_CIxWOxN<check_bounds, DataTileCount, InputLayout>;
    using FilterGlobal2ShareMemVisitor =
            Global2ShareMemVisitor_CIxN<check_bounds, FilterTileCount, KernLayout>;
    static bool constexpr pipelined = RegBlockConfig::reg_k_packed > 1;
    using BlockConsumer =
            IConvBlockConsumerUnrollWidth<RegBlockConfig, ThreadConfig, pipelined>;
    using GlobalMemoryWriter =
            IConvGlobalMemoryWriterUnrollWidth<RegBlockConfig, ThreadConfig>;
};

#undef COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM
#undef MEGDNN_COMMA

}  // namespace convolution
}  // namespace cuda
}  // namespace megdnn

// vim: ft=cpp syntax=cuda.doxygen foldmethod=marker foldmarker=f{{{,f}}}