#include "src/arm_common/resize/opr_impl.h"
#include "src/arm_common/handle.h"
#include "src/arm_common/resize/direct_nchwxx.h"
#include "src/arm_common/resize/resize_cv.h"
#include "src/arm_common/resize/upsample2_nchw.h"
#include "src/arm_common/resize/upsample2_nchwxx.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "midout.h"
MIDOUT_DECL(megdnn_arm_resize)
namespace megdnn {
namespace arm_common {
// Entry point of the arm_common resize operator.
//
// After check_exec() has been called on the layouts and workspace size, one
// of the following paths is taken (first match wins):
//   1. resize_cv_exec()      - NHWC input whose 4th dim is 1 or 3 and for
//                              which is_nhwc_contig_wc() holds;
//   2. fallback              - non-contiguous layouts, src/dst dtype mismatch,
//                              dtype other than fp32/fp16, or an interpolation
//                              mode other than nearest/linear;
//   3. fp32 kernels          - NCHW44 (upsample2 or direct) and NCHW
//                              (upsample2 only);
//   4. fp16 kernels          - NCHW88 (upsample2 or direct) and NCHW
//                              (upsample2 only); compiled in only when
//                              __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is set;
//   5. fallback              - everything that is left.
//
// Each specialised kernel is wrapped in its own MIDOUT region
// (midout_iv(0) .. midout_iv(11)).
void ResizeImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);

    using Fmt = param::Resize::Format;
    using Interp = param::Resize::InterpolationMode;
    const auto fmt = param().format;
    const auto interp = param().imode;

    // ---- dtype related flags ----
    const bool src_is_fp32 = src.layout.dtype == dtype::Float32();
    const bool src_is_fp16 =
            DNN_FLOAT16_SELECT(src.layout.dtype == dtype::Float16(), false);
    const bool dtype_ok =
            (src.layout.dtype == dst.layout.dtype) && (src_is_fp32 || src_is_fp16);

    // ---- format related flags ----
    const bool fmt_is_nchw = fmt == Fmt::NCHW && (src_is_fp32 || src_is_fp16);
    const bool fmt_is_nchw44_fp32 = fmt == Fmt::NCHW44 && src_is_fp32;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    const bool fmt_is_nchw88_fp16 = fmt == Fmt::NCHW88 && src_is_fp16;
#endif

    // ---- interpolation mode flags ----
    const bool pick_nearest = interp == Interp::INTER_NEAREST;
    // NOTE: this name is kept because it is the expression handed to
    // megdnn_assert below.
    const bool is_imode_linear = interp == Interp::INTER_LINEAR;

    // "upsample2": dims 2 and 3 of dst are exactly twice those of src
    // (presumably the spatial H/W axes for the NCHW-like formats that use
    // this flag).
    const bool doubles_dims =
            src.layout.shape[2] * 2 == dst.layout.shape[2] &&
            src.layout.shape[3] * 2 == dst.layout.shape[3];

    const bool layouts_contiguous =
            src.layout.is_contiguous() && dst.layout.is_contiguous();
    const bool usable =
            layouts_contiguous && dtype_ok && (pick_nearest || is_imode_linear);

    // Path 1: NHWC input with 1 or 3 entries in dim 3 goes to resize_cv_exec.
    // This check intentionally precedes the `usable` test.
    if (fmt == Fmt::NHWC && (src.layout[3] == 1 || src.layout[3] == 3) &&
        is_nhwc_contig_wc(src.layout)) {
        MEGDNN_DISPATCH_CPU_KERN_OPR(resize_cv_exec(src, dst, param().imode));
        return;
    }

    // Path 2: nothing below can handle this configuration.
    if (!usable) {
        fallback::ResizeImpl::exec(src, dst, workspace);
        return;
    }

    // Path 3: fp32 kernels.
    if (src_is_fp32) {
        auto kern_param = KernParam<float>::from_tensors(
                param().format, param().imode, src, dst, workspace);
        if (fmt_is_nchw44_fp32) {
            if (doubles_dims && pick_nearest) {
                MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(0)) {
                    MEGDNN_DISPATCH_CPU_KERN_OPR(
                            resize_nearest_upsample2_nchw44_fp32(kern_param));
                }
                MIDOUT_END();
            } else if (doubles_dims) {
                megdnn_assert(is_imode_linear, "invalid imode");
                MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(1)) {
                    MEGDNN_DISPATCH_CPU_KERN_OPR(
                            resize_linear_upsample2_nchw44_fp32(kern_param));
                }
                MIDOUT_END();
            } else if (pick_nearest) {
                MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(2)) {
                    MEGDNN_DISPATCH_CPU_KERN_OPR(
                            resize_direct_nearest_nchw44_fp32(kern_param));
                }
                MIDOUT_END();
            } else {
                megdnn_assert(is_imode_linear, "invalid imode");
                MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(3)) {
                    MEGDNN_DISPATCH_CPU_KERN_OPR(
                            resize_direct_linear_nchw44_fp32(kern_param));
                }
                MIDOUT_END();
            }
        } else if (fmt_is_nchw && doubles_dims) {
            if (pick_nearest) {
                MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(4)) {
                    MEGDNN_DISPATCH_CPU_KERN_OPR(
                            resize_nearest_upsample2_nchw_fp32(kern_param));
                }
                MIDOUT_END();
            } else {
                megdnn_assert(is_imode_linear, "invalid imode");
                MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(5)) {
                    MEGDNN_DISPATCH_CPU_KERN_OPR(
                            resize_linear_upsample2_nchw_fp32(kern_param));
                }
                MIDOUT_END();
            }
        } else {
            // NCHW without the 2x relation, or any other format.
            fallback::ResizeImpl::exec(src, dst, workspace);
        }
        return;
    }

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // Path 4: fp16 kernels.
    if (src_is_fp16) {
        auto kern_param = KernParam<dt_float16>::from_tensors(
                param().format, param().imode, src, dst, workspace);
        if (fmt_is_nchw88_fp16) {
            if (doubles_dims && pick_nearest) {
                MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(6)) {
                    MEGDNN_DISPATCH_CPU_KERN_OPR(
                            resize_nearest_upsample2_nchw88_fp16(kern_param));
                }
                MIDOUT_END();
            } else if (doubles_dims) {
                megdnn_assert(is_imode_linear, "invalid imode");
                MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(7)) {
                    MEGDNN_DISPATCH_CPU_KERN_OPR(
                            resize_linear_upsample2_nchw88_fp16(kern_param));
                }
                MIDOUT_END();
            } else if (pick_nearest) {
                MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(8)) {
                    MEGDNN_DISPATCH_CPU_KERN_OPR(
                            resize_direct_nearest_nchw88_fp16(kern_param));
                }
                MIDOUT_END();
            } else {
                megdnn_assert(is_imode_linear, "invalid imode");
                MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(9)) {
                    MEGDNN_DISPATCH_CPU_KERN_OPR(
                            resize_direct_linear_nchw88_fp16(kern_param));
                }
                MIDOUT_END();
            }
        } else if (fmt_is_nchw && doubles_dims) {
            if (pick_nearest) {
                MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(10)) {
                    MEGDNN_DISPATCH_CPU_KERN_OPR(
                            resize_nearest_upsample2_nchw_fp16(kern_param));
                }
                MIDOUT_END();
            } else {
                megdnn_assert(is_imode_linear, "invalid imode");
                MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(11)) {
                    MEGDNN_DISPATCH_CPU_KERN_OPR(
                            resize_linear_upsample2_nchw_fp16(kern_param));
                }
                MIDOUT_END();
            }
        } else {
            // NCHW without the 2x relation, or any other format.
            fallback::ResizeImpl::exec(src, dst, workspace);
        }
        return;
    }
#endif

    // Path 5: usable configuration without a specialised kernel in this build.
    fallback::ResizeImpl::exec(src, dst, workspace);
}
} }