#include "src/cuda/warp_perspective/opr_impl.h"
#include "src/cuda/warp_perspective/warp_perspective_cv.cuh"
#include "src/cuda/utils.h"
#include "src/cuda/warp_perspective/common.h"
#include "src/cuda/warp_perspective/helper.h"
#include "src/common/cv/common.h"
#include "src/common/warp_common.h"
namespace megdnn {
namespace cuda {
namespace {
inline void deduce_reformat_layout(
std::unique_ptr<RelayoutFormat>& relayout, const TensorLayout& src_layout,
TensorLayout& dst_layout, RelayoutFormat::Param::Mode mode, const int oc = 0,
const int group = 1) {
if (src_layout.ndim > 0) {
RelayoutFormat::Param trans_param;
trans_param.mode = mode;
trans_param.oc = oc;
trans_param.group = group;
relayout->param() = trans_param;
relayout->deduce_layout(src_layout, dst_layout);
} else {
dst_layout = src_layout;
}
}
void get_inner_layout(
const TensorLayout& src, const TensorLayout& dst, TensorLayout& inner_src,
TensorLayout& inner_dst, Handle* handle,
WarpPerspectiveForwardImpl::Param::Format format) {
if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
src.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
dst.dtype.enumv() == src.dtype.enumv() &&
format == param::WarpPerspective::Format::NCHW) {
auto relayout_opr = handle->create_operator<RelayoutFormat>();
deduce_reformat_layout(
relayout_opr, src, inner_src, RelayoutFormat::Param::Mode::NCHW_NCHW64,
0, 1);
deduce_reformat_layout(
relayout_opr, dst, inner_dst, RelayoutFormat::Param::Mode::NCHW_NCHW64,
0, 1);
} else {
megdnn_assert(0, "not support");
}
}
}
namespace warp_perspective {
void warp_perspective_cv_exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in dst,
float border_val, BorderMode bmode, InterpolationMode imode,
_megdnn_workspace workspace, cudaStream_t stream) {
megdnn_assert(src.layout[3] == 1 || src.layout[3] == 3, "unsupported src channel");
megdnn_assert(
src.layout.dtype != dtype::Float32() || src.layout.dtype != dtype::Uint8(),
"unsupported src dtype");
if (imode == InterpolationMode::INTER_AREA) {
imode = InterpolationMode::INTER_LINEAR;
}
using namespace megcv;
const float* trans_ptr = mat.ptr<dt_float32>();
double* workspace_ptr = workspace.ptr<double>();
for (size_t i = 0; i < src.layout.shape[0]; ++i) {
if (dst.layout.dtype == dtype::Float32()) {
Mat<float> src_mat = TensorND2Mat<float>(src, i);
Mat<float> dst_mat = TensorND2Mat<float>(dst, i);
if (src_mat.channels() == 1) {
warp_perspective_cv_proxy<float, 1>(
src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(),
dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(),
bmode, imode, trans_ptr, border_val, workspace_ptr, stream);
} else {
warp_perspective_cv_proxy<float, 3>(
src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(),
dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(),
bmode, imode, trans_ptr, border_val, workspace_ptr, stream);
}
} else if (dst.layout.dtype == dtype::Uint8()) {
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
if (src_mat.channels() == 1) {
warp_perspective_cv_proxy<uchar, 1>(
src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(),
dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(),
bmode, imode, trans_ptr, static_cast<uchar>(border_val),
workspace_ptr, stream);
} else {
warp_perspective_cv_proxy<uchar, 3>(
src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(),
dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(),
bmode, imode, trans_ptr, static_cast<uchar>(border_val),
workspace_ptr, stream);
}
} else {
megdnn_throw("Unsupported datatype of WarpPerspective optr.");
}
trans_ptr += 3 * 3;
workspace_ptr += 3 * 3;
}
}
}
WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
void* ptr, const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) const {
MEGDNN_MARK_USED_VAR(mat_idx);
SmallVector<size_t> sizes;
TensorLayout fsrc = src;
TensorLayout fmat = mat;
TensorLayout fdst = dst;
if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
src.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
param().format == param::WarpPerspective::Format::NCHW) {
get_inner_layout(src, dst, fsrc, fdst, handle(), param().format);
sizes.push_back(fsrc.span().dist_byte());
sizes.push_back(fdst.span().dist_byte());
} else {
auto get_workspace = [&sizes](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
sizes.push_back(layout.span().dist_byte());
}
};
get_workspace(fsrc);
get_workspace(fmat);
get_workspace(fdst);
}
if (param().format == param::WarpPerspective::Format::NHWC) {
sizes.push_back(mat.total_nr_elems() * sizeof(double));
}
return {ptr, std::move(sizes)};
}
void WarpPerspectiveForwardImpl::exec(
_megdnn_tensor_in ssrc, _megdnn_tensor_in smat, _megdnn_tensor_in smat_idx,
_megdnn_tensor_out sdst, _megdnn_workspace sworkspace) {
check_exec_allow_nhwc_mat_idx(
ssrc.layout, smat.layout, smat_idx.layout, sdst.layout, sworkspace.size);
TensorND src = ssrc;
TensorND mat = smat;
TensorND mat_idx = smat_idx;
TensorND dst = sdst;
Param::Format inner_format = param().format;
auto bundle = get_workspace_bundle(
sworkspace.raw_ptr, ssrc.layout, smat.layout, smat_idx.layout, sdst.layout);
auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
concrete_handle(this->handle()), &bundle);
if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
ctypecvt.src_to_comp_type(ssrc, src)
.src_to_comp_type(smat, mat)
.src_to_comp_type(sdst, dst);
} else if (
(ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
ssrc.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
param().format == Param::Format::NCHW) {
auto handle_ptr = handle();
get_inner_layout(
ssrc.layout, sdst.layout, src.layout, dst.layout, handle_ptr,
param().format);
src = TensorND{bundle.get(0), src.layout};
dst = TensorND{bundle.get(1), dst.layout};
auto relayout_opr = handle_ptr->create_operator<RelayoutFormat>();
RelayoutFormat::Param trans_param;
trans_param.mode = RelayoutFormat::Param::Mode::NCHW_NCHW64;
relayout_opr->param() = trans_param;
relayout_opr->exec(ssrc, src, {});
inner_format = Param::Format::NCHW64;
}
{
auto stream = cuda_stream(this->handle());
bool is_nhwc = inner_format == param::WarpPerspective::Format::NHWC;
if (is_nhwc && param().imode != Param::InterpolationMode::LINEAR) {
megdnn_assert(
!mat_idx.raw_ptr(),
"mat_idx is not supported in NHWC case with "
"non-linear interpolation");
warp_perspective::warp_perspective_cv_exec(
src, mat, dst, param().border_val,
warp_perspective::get_bmode(param().bmode),
warp_perspective::get_imode(param().imode), ctypecvt.workspace(),
stream);
} else {
megdnn_assert(warp::is_dnn_available(
src.layout, mat.layout, dst.layout, param().imode, inner_format));
size_t C, IH, IW, OH, OW;
if (is_nhwc) {
C = src.layout.shape[3];
IH = src.layout.shape[1];
IW = src.layout.shape[2];
OH = dst.layout.shape[1];
OW = dst.layout.shape[2];
} else if (inner_format == Param::Format::NCHW4) {
C = src.layout.shape[1] * 4;
IH = src.layout.shape[2];
IW = src.layout.shape[3];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
} else if (inner_format == Param::Format::NHWC_NCHW) {
C = src.layout.shape[3];
IH = src.layout.shape[1];
IW = src.layout.shape[2];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
} else if (inner_format == Param::Format::NHWC_NCHW4_IC_SMALL) {
C = src.layout.shape[3];
IH = src.layout.shape[1];
IW = src.layout.shape[2];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
megdnn_assert(
(C == 1) || (C == 3),
"NHWC_NCHW4_IC_SMALL only support C == 1 or C == 3");
} else if (inner_format == Param::Format::NCHW_NCHW4_IC_SMALL) {
C = src.layout.shape[1];
IH = src.layout.shape[2];
IW = src.layout.shape[3];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
megdnn_assert(
(C == 1) || (C == 3),
"NCHW_NCHW4_IC_SMALL only support C == 1 or C == 3");
} else if (inner_format == Param::Format::NCHW64) {
C = src.layout.shape[1] * 64;
IH = src.layout.shape[2];
IW = src.layout.shape[3];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
} else {
megdnn_assert(
inner_format == param::WarpPerspective::Format::NCHW,
"invalid warp_perspective format");
C = src.layout.shape[1];
IH = src.layout.shape[2];
IW = src.layout.shape[3];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
}
megdnn_assert(
param().imode == Param::InterpolationMode::LINEAR,
"unsupported interpolation mode for NCHW format");
auto bval = param().border_val;
auto bmode = warp_perspective::get_bmode(param().bmode);
if (src.layout.dtype == dst.layout.dtype) {
if (src.layout.dtype == dtype::Float32{}) {
warp_perspective::forward_proxy(
is_nhwc, src.ptr<dt_float32>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_float32>(), src.layout[0], mat.layout[0], C, IH,
IW, OH, OW, bval, bmode, async_error_info(handle()),
m_error_tracker, stream);
} else if (DNN_FLOAT16_SELECT(
src.layout.dtype == dtype::Float16(), false)) {
#ifndef MEGDNN_DISABLE_FLOAT16
warp_perspective::forward_proxy(
is_nhwc, src.ptr<dt_float16>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_float16>(), src.layout[0], mat.layout[0], C, IH,
IW, OH, OW, static_cast<dt_float16>(bval), bmode,
async_error_info(handle()), m_error_tracker, stream);
#endif
} else if (src.layout.dtype == dtype::Uint8()) {
warp_perspective::forward_proxy<dt_uint8>(
is_nhwc, src.ptr<dt_uint8>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_uint8>(), src.layout[0], mat.layout[0], C, IH,
IW, OH, OW, bval, bmode, async_error_info(handle()),
m_error_tracker, stream);
} else if (src.layout.dtype == dtype::Int8()) {
megdnn_assert(
!is_nhwc,
"WarpPerspective on CUDA does not support "
"NHWC + Int8");
warp_perspective::forward_proxy<dt_int8>(
false, src.ptr<dt_int8>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_int8>(), src.layout[0], mat.layout[0], C, IH, IW,
OH, OW, bval
,
bmode, async_error_info(handle()), m_error_tracker, stream);
} else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
megdnn_assert(
param().format == Param::Format::NCHW4,
"WarpPerspective on CUDA supports NCHW4 + "
"QuantizedS8 only");
warp_perspective::forward_proxy_nchw4<dt_int8>(
src.compatible_ptr<dt_int8>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.compatible_ptr<dt_int8>(), src.layout[0], mat.layout[0],
C, IH, IW, OH, OW, bval, bmode, async_error_info(handle()),
m_error_tracker, stream);
} else if (
(src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) &&
(param().format == Param::Format::NCHW64 ||
param().format == Param::Format::NCHW)) {
bval = roundf(bval);
bval = fmin(fmax(-8.f, bval), 7.f);
warp_perspective::forward_proxy_nchw64<dt_qint4>(
src.compatible_ptr<dt_qint4>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.compatible_ptr<dt_qint4>(), src.layout[0],
mat.layout[0], C, IH, IW, OH, OW,
static_cast<dt_qint4>(bval), bmode,
async_error_info(handle()), m_error_tracker, stream);
if (param().format == Param::Format::NCHW) {
auto relayout_opr = handle()->create_operator<RelayoutFormat>();
RelayoutFormat::Param trans_param;
trans_param.mode = RelayoutFormat::Param::Mode::NCHW64_NCHW;
trans_param.oc = sdst.layout[1];
relayout_opr->param() = trans_param;
relayout_opr->exec(dst, sdst, {});
}
} else if (
(src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
(param().format == Param::Format::NCHW64 ||
param().format == Param::Format::NCHW)) {
bval = roundf(bval);
bval = fmin(fmax(0, bval), 15);
warp_perspective::forward_proxy_nchw64<dt_quint4>(
src.compatible_ptr<dt_quint4>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.compatible_ptr<dt_quint4>(), src.layout[0],
mat.layout[0], C, IH, IW, OH, OW,
static_cast<dt_quint4>(bval), bmode,
async_error_info(handle()), m_error_tracker, stream);
if (param().format == Param::Format::NCHW) {
auto relayout_opr = handle()->create_operator<RelayoutFormat>();
RelayoutFormat::Param trans_param;
trans_param.mode = RelayoutFormat::Param::Mode::NCHW64_NCHW;
trans_param.oc = sdst.layout[1];
relayout_opr->param() = trans_param;
relayout_opr->exec(dst, sdst, {});
}
} else if (
(src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
(param().format == Param::Format::NHWC)) {
constexpr int pack_c = 8;
megdnn_assert(C % pack_c == 0);
bval = roundf(bval);
if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
bval = fmin(fmax(-8.f, bval), 7.f);
if (C % 16 == 0) {
warp_perspective::forward_proxy_nhwc_bit4<dt_qint4, 16>(
src.ptr<dt_qint4>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_qint4>(), src.layout[0], mat.layout[0],
C, IH, IW, OH, OW, static_cast<dt_qint4>(bval),
bmode, async_error_info(handle()), m_error_tracker,
stream);
} else {
warp_perspective::forward_proxy_nhwc_bit4<dt_qint4, pack_c>(
src.ptr<dt_qint4>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_qint4>(), src.layout[0], mat.layout[0],
C, IH, IW, OH, OW, static_cast<dt_qint4>(bval),
bmode, async_error_info(handle()), m_error_tracker,
stream);
}
} else {
bval = fmin(fmax(0.f, bval), 15.f);
if (C % 16 == 0) {
warp_perspective::forward_proxy_nhwc_bit4<dt_quint4, 16>(
src.ptr<dt_quint4>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_quint4>(), src.layout[0], mat.layout[0],
C, IH, IW, OH, OW, static_cast<dt_quint4>(bval),
bmode, async_error_info(handle()), m_error_tracker,
stream);
} else {
warp_perspective::forward_proxy_nhwc_bit4<
dt_quint4, pack_c>(
src.ptr<dt_quint4>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_quint4>(), src.layout[0], mat.layout[0],
C, IH, IW, OH, OW, static_cast<dt_quint4>(bval),
bmode, async_error_info(handle()), m_error_tracker,
stream);
}
}
}
} else if ((src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
src.layout.dtype.enumv() == DTypeEnum::Uint8)) {
uint8_t zero_point = 0;
float scale = 1.f;
if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
zero_point =
src.layout.dtype.param<dtype::Quantized8Asymm>().zero_point;
scale = src.layout.dtype.param<dtype::Quantized8Asymm>().scale;
} else if (
src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
zero_point = 128;
scale = 1.f;
}
DTypeParamImpl<dt_quint8> src_dtype_param(scale, zero_point);
if ((dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&
dst.layout.dtype.param<dtype::QuantizedS8>().scale == scale) &&
((param().format == Param::Format::NCHW_NCHW4_IC_SMALL) ||
(param().format == Param::Format::NHWC_NCHW4_IC_SMALL))) {
bool is_nhwc_ic_small =
(param().format == Param::Format::NHWC_NCHW4_IC_SMALL);
warp_perspective::forward_proxy_quint8_dimshuffle_typecvt_nchw4<
dt_quint8, dt_uint8, dt_int8>(
is_nhwc_ic_small, src.compatible_ptr<dt_uint8>(),
mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.compatible_ptr<dt_int8>(), src.layout[0], mat.layout[0],
C, IH, IW, OH, OW, bval, src_dtype_param, bmode,
async_error_info(handle()), m_error_tracker, stream);
} else {
megdnn_assert(
((dst.layout.dtype.enumv() == DTypeEnum::Float32) &&
((param().format == Param::Format::NCHW) ||
(param().format == Param::Format::NHWC_NCHW))),
"invalid format for Quantized8Asymm input");
bool is_nhwc = (param().format == Param::Format::NHWC_NCHW);
warp_perspective::forward_proxy_quint8_dimshuffle_typecvt_nchw<
dt_quint8, dt_uint8, dt_float32>(
is_nhwc, src.compatible_ptr<dt_uint8>(),
mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.compatible_ptr<dt_float32>(), src.layout[0],
mat.layout[0], C, IH, IW, OH, OW, bval, src_dtype_param,
bmode, async_error_info(handle()), m_error_tracker, stream);
}
} else {
megdnn_throw(
ssprintf("unsupported dtype: %s", src.layout.dtype.name()));
}
}
}
if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
ctypecvt.comp_to_dst_type(dst, sdst);
}
}
} }