#include "src/cuda/relayout_format/opr_impl.h"
#include "src/cuda/handle.h"
#include "src/cuda/relayout_format/relayout_format.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
/*!
 * \brief Execute the relayout-format operator on the CUDA backend.
 *
 * Dispatches on param().mode (and, for some modes, on the source dtype) to
 * one of four paths:
 *   1. NCHW4_CHWN4 / CHWN4_NCHW4 with QuantizedS8 input: a 2D transpose of
 *      Int32 elements done by the generic RelayoutForward operator.
 *   2. NCHW_NCHW4_IC_SMALL / NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT when the
 *      source channel count is not a multiple of 4: zero-fill dst, then copy
 *      the source channels into a strided view of dst.
 *   3. NCHW_NCHW4 / NCHW4_NCHW / NCHW_NCHW4_WEIGHT (any dtype), or
 *      NCHW_NCHW64 / NCHW64_NCHW / NCHW_NHWC / NHWC_NCHW with a 4-bit
 *      quantized dtype: the dedicated RelayoutFormatFast kernel.
 *   4. Everything else: layouts from deduce_exec_layout() fed to the generic
 *      RelayoutForward operator.
 *
 * \param src input tensor
 * \param dst output tensor; path 2 requires it not to share src's pointer
 * \param workspace unused (get_workspace_in_bytes() reports 0 bytes)
 *
 * An unsupported mode, an in-place IC_SMALL request or an unusable fast
 * kernel is reported through megdnn_assert.
 */
void RelayoutFormatImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst,
        _megdnn_workspace /* workspace */) {
    using Mode = Param::Mode;
    const auto mode = param().mode;
    const auto src_enumv = src.layout.dtype.enumv();

    // Group the accepted modes by the execution path that may serve them.
    const bool is_chwn4_transpose =
            mode == Mode::NCHW4_CHWN4 || mode == Mode::CHWN4_NCHW4;
    const bool is_ic_small =
            mode == Mode::NCHW_NCHW4_IC_SMALL ||
            mode == Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT;
    const bool is_nchw_nchw4 = mode == Mode::NCHW_NCHW4 ||
                               mode == Mode::NCHW4_NCHW ||
                               mode == Mode::NCHW_NCHW4_WEIGHT;
    const bool is_4bit_capable_mode =
            mode == Mode::NCHW_NCHW64 || mode == Mode::NCHW64_NCHW ||
            mode == Mode::NCHW_NHWC || mode == Mode::NHWC_NCHW;

    megdnn_assert(
            is_chwn4_transpose || is_ic_small || is_nchw_nchw4 ||
                    is_4bit_capable_mode,
            "relayout format of cuda only supports the modes NCHW4_CHWN4, "
            "CHWN4_NCHW4, NCHW_NCHW4, NCHW_NCHW4_WEIGHT, NCHW4_NCHW, "
            "NCHW_NCHW4_IC_SMALL, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT, "
            "NCHW_NCHW64, NCHW64_NCHW, NCHW_NHWC and NHWC_NCHW");

    // Path 1: int8 NCHW4 <-> CHWN4 expressed as a (row, col) matrix transpose.
    if (is_chwn4_transpose && src_enumv == DTypeEnum::QuantizedS8) {
        size_t row = 0, col = 0;
        if (mode == Mode::NCHW4_CHWN4) {
            row = src.layout[0];
            col = src.layout[1] * src.layout[2] * src.layout[3];
        } else {
            // mode == Mode::CHWN4_NCHW4, guaranteed by is_chwn4_transpose
            row = src.layout[0] * src.layout[1] * src.layout[2];
            col = src.layout[3];
        }
        // The buffers are viewed as Int32 elements (presumably one packed
        // group of four int8 values each -- confirm against the layout spec).
        TensorLayout layout_in{{row, col}, dtype::Int32()};
        layout_in.init_contiguous_stride();
        // Same logical shape, but the two strides are swapped so that the
        // copy writes dst in transposed order.
        TensorLayout layout_out = layout_in;
        layout_out.stride[0] = 1;
        layout_out.stride[1] = row;
        TensorND trans_in{src.raw_ptr(), layout_in};
        TensorND trans_out{dst.raw_ptr(), layout_out};
        handle()->create_operator<RelayoutForward>()->exec(trans_in, trans_out);
        return;
    }

    // Path 2: channel count is not a multiple of 4, so dst needs padding.
    if (is_ic_small && src.layout[1] % 4 != 0) {
        megdnn_assert(
                src.raw_ptr() != dst.raw_ptr(),
                "The mode of NCHW_NCHW4_IC_SMALL and "
                "NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT "
                "of RelayoutFormat opr(cuda backend) does not support "
                "src.ptr == dst.ptr");
        megdnn_assert(
                src.layout.ndim == 4,
                "The mode of NCHW_NCHW4_IC_SMALL and "
                "NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT "
                "of RelayoutFormat opr(cuda backend) requires a "
                "4-dimensional src, got %s",
                src.layout.to_string().c_str());
        // Together with the `% 4 != 0` guard above this limits the channel
        // count to 1..3.
        megdnn_assert(src.layout[1] <= 4);
        // Zero the whole destination first; the copy below only writes the
        // first src.layout[1] entries along dst's last axis.
        cuda_check(cudaMemsetAsync(
                dst.raw_ptr(), 0, dst.layout.span().dist_byte(),
                cuda_stream(this->handle())));
        // Shrink dst's last axis to the real channel count while keeping its
        // strides, so the relayout only touches the valid entries.
        TensorLayout exec_dst_layout = dst.layout;
        exec_dst_layout[4] = src.layout[1];
        // View src as {dim0, 1, dim2, dim3, dim1}: the channel axis moves last.
        TensorLayout exec_src_layout =
                src.layout
                        .reshape(
                                {src.layout[0], src.layout[1], 1, src.layout[2],
                                 src.layout[3]})
                        .dimshuffle({0, 2, 3, 4, 1});
        handle()->create_operator<RelayoutForward>()->exec(
                {src.raw_ptr(), exec_src_layout},
                {dst.raw_ptr(), exec_dst_layout});
        return;
    }

    // Path 3: dedicated kernel. The 4-bit-capable modes only take this path
    // when the source really is a 4-bit quantized dtype.
    const bool is_trans_4bits =
            is_4bit_capable_mode && (src_enumv == DTypeEnum::QuantizedS4 ||
                                     src_enumv == DTypeEnum::Quantized4Asymm);
    if (is_trans_4bits || is_nchw_nchw4) {
        const bool is_usable = relayout_format::RelayoutFormatFast::usable(
                src.layout, dst.layout, mode);
        megdnn_assert(
                is_usable,
                "RelayoutFormatFast kernel is not usable for "
                "transforming %s(%s) to %s(%s).",
                src.layout.to_string().c_str(), src.layout.dtype.name(),
                dst.layout.to_string().c_str(), dst.layout.dtype.name());
        relayout_format::RelayoutFormatFast::exec(
                src, dst, cuda_stream(this->handle()), mode, param().group);
        return;
    }

    // Path 4: generic fallback using the layouts deduced by the base operator.
    TensorLayout exec_src, exec_dst, exec_workspace;
    deduce_exec_layout(src.layout, dst.layout, exec_workspace, exec_src, exec_dst);
    TensorND exec_src_nd{src.raw_ptr(), exec_src};
    TensorND exec_dst_nd{dst.raw_ptr(), exec_dst};
    handle()->create_operator<RelayoutForward>()->exec(exec_src_nd, exec_dst_nd);
}
/*!
 * \brief Report the scratch memory required by exec().
 *
 * Both layout arguments are ignored: this implementation always reports
 * that no workspace is needed.
 *
 * \return always 0
 */
size_t RelayoutFormatImpl::get_workspace_in_bytes(
        const TensorLayout& /* src */, const TensorLayout& /* dst */) {
    constexpr size_t kNoWorkspace = 0;
    return kNoWorkspace;
}