/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are
*permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this
*list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this
*list of conditions and the following disclaimer in the documentation and/or other
*materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors
*may be used to endorse or promote products derived from this software without specific
*prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
*EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
*OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
*SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
*EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
*HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
*OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
*SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/convolution_helper/conv_trait/iconv_trait.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/cuda/convolution_helper/block_tile_consumer/block_consumer.cuh"
#include "src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator.cuh"
#include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor.cuh"
#include "src/cuda/convolution_helper/global_memory_writer/global_memory_writer.cuh"
#include "src/cuda/convolution_helper/layout.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
namespace megdnn {
namespace cuda {
namespace convolution {
//! Declares the element-type / layout / parameter aliases shared by the
//! convolution traits below, plus the `check_bounds` flag.
//!
//! Expects a constant named `check_bounds_` (a template parameter of the
//! traits below) to be in scope at the expansion site. The expansion does not
//! end with a semicolon: the caller supplies it, so that
//! `COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(...);` does not leave a stray
//! empty declaration behind.
#define COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(                     \
        _src_dtype, _filter_dtype, _smem_storage_dtype, _input_layout,   \
        _kern_layout, _output_layout, _conv_param)                       \
    using src_dtype = _src_dtype;                                        \
    using filter_dtype = _filter_dtype;                                  \
    using smem_storage_dtype = _smem_storage_dtype;                      \
    using InputLayout = _input_layout;                                   \
    using KernLayout = _kern_layout;                                     \
    using OutputLayout = _output_layout;                                 \
    using Param = _conv_param;                                           \
    static constexpr bool check_bounds = check_bounds_
//! Expands to a bare comma, so that template argument lists can be passed
//! through function-like macros (e.g. MEGDNN_STATIC_ASSERT) without being
//! split into several macro arguments.
#define MEGDNN_COMMA ,
/*!
 * \brief Compile-time configuration of an int8 convolution whose input,
 *        filter and output tensors all use the CHWN4 layout.
 *
 * Bundles the element types, the shared-memory tile geometry of the data and
 * filter tiles, and the component types (block tile iterator, global->shared
 * memory visitors, block consumer, global memory writer) for one kernel
 * configuration.
 *
 * \tparam check_bounds_ exported as `check_bounds` and forwarded to both
 *         Global2ShareMemVisitor_CIxN instantiations (presumably enables
 *         bounds checks on global loads).
 * \tparam ldg_dtype type of a single global-memory load (`copy_t`).
 * \tparam RegBlockConfig_ register blocking; this trait reads `reg_m`,
 *         `reg_n`, `reg_k` and `reg_k_packed` from it.
 * \tparam ThreadConfig_ thread-block shape; this trait reads `nr_thread_x`,
 *         `nr_thread_y` and `nr_threads` from it.
 */
template <
        bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_,
        typename ThreadConfig_>
struct IConvTrait {
    COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(
            int8_t, int8_t, int32_t, Layout<Format::CHWN4>, Layout<Format::CHWN4>,
            Layout<Format::CHWN4>, ConvParam);
    using RegBlockConfig = RegBlockConfig_;
    using ThreadConfig = ThreadConfig_;

    //! Shared-memory tile geometry of the input (data) tensor.
    struct DataTileCount {
        // Re-exported for users of this tile description. Spelled via the
        // template parameters / a qualified outer name because
        // `using X = X;` would change what `X` means inside this class
        // ([basic.scope.class], ill-formed NDR; GCC: -Wchanges-meaning).
        using RegBlockConfig = RegBlockConfig_;
        using ThreadConfig = ThreadConfig_;
        using copy_t = ldg_dtype;
        using smem_storage_dtype = typename IConvTrait::smem_storage_dtype;

        // Elements moved by one copy_t load, counted in shared-memory storage
        // elements and in source elements respectively.
        static int constexpr load_width =
                sizeof(copy_t) / sizeof(smem_storage_dtype);
        static_assert(
                load_width > 0,
                "ldg_dtype must be at least as wide as smem_storage_dtype");
        static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(src_dtype);
        static int constexpr skew = load_width;
        static int constexpr block_tile_batch =
                RegBlockConfig::reg_n * ThreadConfig::nr_thread_x;
        static int constexpr block_tile_in_channel = RegBlockConfig::reg_k;

        // Loads per smem row, and the load_x x load_y arrangement of the
        // loading threads (load_x capped at 32, presumably one warp).
        static int constexpr smem_load_x = block_tile_batch / load_width;
        static_assert(
                smem_load_x > 0, "block_tile_batch must be at least load_width");
        static int constexpr load_x = smem_load_x > 32 ? 32 : smem_load_x;
        static int constexpr load_y = ThreadConfig::nr_threads / load_x;
        static_assert(load_y > 0, "ThreadConfig::nr_threads must be at least load_x");

        static int constexpr smem_h = RegBlockConfig::reg_k_packed;
        static int constexpr smem_w = block_tile_batch;
        // Even-width rows are padded by `skew` elements (presumably to avoid
        // shared-memory bank conflicts).
        static int constexpr smem_stride = smem_w % 2 == 0 ? smem_w + skew : smem_w;
        static int constexpr smem_tot = smem_h * smem_stride;

        // ceil(smem_h / load_y) and ceil(smem_load_x / load_x).
        static int constexpr reg_h = (smem_h + load_y - 1) / load_y;
        static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x;

        // True when the tile does not split evenly over the loading threads.
        static bool constexpr check_bounds_h = smem_h % load_y != 0;
        static bool constexpr check_bounds_w = smem_load_x % load_x != 0;
    };

    //! Shared-memory tile geometry of the filter tensor.
    struct FilterTileCount {
        // Re-exported for users of this tile description. Spelled via the
        // template parameters / a qualified outer name because
        // `using X = X;` would change what `X` means inside this class
        // ([basic.scope.class], ill-formed NDR; GCC: -Wchanges-meaning).
        using RegBlockConfig = RegBlockConfig_;
        using ThreadConfig = ThreadConfig_;
        using copy_t = ldg_dtype;
        using smem_storage_dtype = typename IConvTrait::smem_storage_dtype;

        // Elements moved by one copy_t load, counted in shared-memory storage
        // elements and in filter elements respectively.
        static int constexpr load_width =
                sizeof(copy_t) / sizeof(smem_storage_dtype);
        static_assert(
                load_width > 0,
                "ldg_dtype must be at least as wide as smem_storage_dtype");
        static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(filter_dtype);
        static int constexpr skew = load_width;
        static int constexpr block_tile_out_channel =
                RegBlockConfig::reg_m * ThreadConfig::nr_thread_y;
        static int constexpr block_tile_in_channel = RegBlockConfig::reg_k;

        // Loads per smem row, and the load_x x load_y arrangement of the
        // loading threads (load_x capped at 32, presumably one warp).
        static int constexpr smem_load_x = block_tile_out_channel / load_width;
        static_assert(
                smem_load_x > 0,
                "block_tile_out_channel must be at least load_width");
        static int constexpr load_x = smem_load_x > 32 ? 32 : smem_load_x;
        static int constexpr load_y = ThreadConfig::nr_threads / load_x;
        static_assert(load_y > 0, "ThreadConfig::nr_threads must be at least load_x");

        static int constexpr smem_h = RegBlockConfig::reg_k_packed;
        static int constexpr smem_w = block_tile_out_channel;
        // Even-width rows are padded by `skew` elements (presumably to avoid
        // shared-memory bank conflicts).
        static int constexpr smem_stride = smem_w % 2 == 0 ? smem_w + skew : smem_w;
        static int constexpr smem_tot = smem_h * smem_stride;

        // ceil(smem_h / load_y) and ceil(smem_load_x / load_x).
        static int constexpr reg_h = (smem_h + load_y - 1) / load_y;
        static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x;

        // True when the tile does not split evenly over the loading threads.
        static bool constexpr check_bounds_h = smem_h % load_y != 0;
        static bool constexpr check_bounds_w = smem_load_x % load_x != 0;
    };

    using BlockTileIterator = BlockTileIteratorBasic<DataTileCount, FilterTileCount>;
    using DataGlobal2ShareMemVisitor =
            Global2ShareMemVisitor_CIxN<check_bounds, DataTileCount, InputLayout>;
    using FilterGlobal2ShareMemVisitor =
            Global2ShareMemVisitor_CIxN<check_bounds, FilterTileCount, KernLayout>;
    // Forwarded to the block consumer; true when the tile holds more than one
    // packed k-slice.
    static bool constexpr pipelined = RegBlockConfig::reg_k_packed > 1;
    using BlockConsumer = IConvBlockConsumer<RegBlockConfig, ThreadConfig, pipelined>;
    using GlobalMemoryWriter = IConvGlobalMemoryWriter<RegBlockConfig, ThreadConfig>;
};
template <
bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_,
typename ThreadConfig_>
struct IConvTraitUnrollWidth {
COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(
int8_t, int8_t, int32_t, Layout<Format::CHWN4>, Layout<Format::CHWN4>,
Layout<Format::CHWN4>, ConvParam);
using RegBlockConfig = RegBlockConfig_;
using ThreadConfig = ThreadConfig_;
struct DataTileCount {
using RegBlockConfig = RegBlockConfig;
using ThreadConfig = ThreadConfig;
using copy_t = ldg_dtype;
using smem_storage_dtype = smem_storage_dtype;
static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype);
static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(src_dtype);
static int constexpr skew = load_width;
static int constexpr block_tile_batch =
RegBlockConfig::reg_n * ThreadConfig::nr_thread_x;
static int constexpr block_tile_out_width = RegBlockConfig::reg_width;
static int constexpr block_tile_in_channel = RegBlockConfig::reg_k;
static int constexpr smem_load_x = block_tile_batch / load_width;
static int constexpr load_x = smem_load_x > 32 ? 32 : smem_load_x;
static int constexpr load_y = ThreadConfig::nr_threads / load_x;
static int constexpr smem_h = RegBlockConfig::reg_k_packed;
static int constexpr smem_w = block_tile_batch;
static int constexpr img_cache = RegBlockConfig::reg_width;
static int constexpr smem_stride = smem_w % 2 == 0 ? smem_w + skew : smem_w;
static int constexpr smem_tot = smem_h * img_cache * smem_stride;
static int constexpr reg_h = (smem_h + load_y - 1) / load_y;
static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x;
static bool constexpr check_bounds_h = smem_h % load_y != 0;
static bool constexpr check_bounds_w = smem_load_x % load_x != 0;
};
MEGDNN_STATIC_ASSERT(
std::is_same<typename IConvTrait<
check_bounds MEGDNN_COMMA ldg_dtype MEGDNN_COMMA
RegBlockConfig MEGDNN_COMMA ThreadConfig>::
filter_dtype MEGDNN_COMMA filter_dtype>::value == true,
"data type of filter tensor should be int8_t");
using FilterTileCount = typename IConvTrait<
check_bounds, ldg_dtype, RegBlockConfig, ThreadConfig>::FilterTileCount;
using BlockTileIterator =
BlockTileIteratorUnrollWidth<DataTileCount, FilterTileCount>;
using DataGlobal2ShareMemVisitor =
Global2ShareMemVisitor_CIxWOxN<check_bounds, DataTileCount, InputLayout>;
using FilterGlobal2ShareMemVisitor =
Global2ShareMemVisitor_CIxN<check_bounds, FilterTileCount, KernLayout>;
static bool constexpr pipelined = RegBlockConfig::reg_k_packed > 1;
using BlockConsumer =
IConvBlockConsumerUnrollWidth<RegBlockConfig, ThreadConfig, pipelined>;
using GlobalMemoryWriter =
IConvGlobalMemoryWriterUnrollWidth<RegBlockConfig, ThreadConfig>;
};
// The helper macros are only needed while the traits above are defined;
// remove them so they are not visible to files that include this header.
#undef MEGDNN_COMMA
#undef COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM
} // namespace convolution
} // namespace cuda
} // namespace megdnn
// vim: ft=cpp syntax=cuda.doxygen foldmethod=marker foldmarker=f{{{,f}}}