#include "src/fallback/relayout/opr_impl.h"
#include "src/common/relayout_helper.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include <cstring>
using namespace megdnn;
using namespace fallback;
namespace {
//! Whether \p layout has at most 3 dims and a unit stride on its last dim.
bool is_lastdim_contig(const TensorLayout& layout) {
    if (layout.ndim > 3) {
        return false;
    }
    const auto last_axis = layout.ndim - 1;
    return layout.stride[last_axis] == 1;
}
/*!
 * \brief Plain storage made of \p num_elems elements of type \p Elem.
 *
 * Used as a stand-in element type when only size (and the alignment of
 * \p Elem) matters, not the real value type.
 */
template <size_t num_elems, typename Elem = char>
struct equiv_ctype_storage {
    Elem _[num_elems];
};
template <typename dtype>
struct equiv_ctype {
using type = std::aligned_storage_t<
sizeof(typename DTypeTrait<dtype>::ctype),
alignof(typename DTypeTrait<dtype>::ctype)>;
};
//! Copy routine taking (contiguous ptr, non-contiguous ptr, byte count); the
//! implementation decides the copy direction.
using memcpy_policy_t = void (*)(void* cont, void* non_cont, size_t);
//! Copy \p size bytes from the contiguous buffer into the non-contiguous one.
void memcpy_cont2noncont(void* cont, void* non_cont, size_t size) {
    void* const to = non_cont;
    const void* const from = cont;
    std::memcpy(to, from, size);
}
//! Copy \p size bytes from the non-contiguous buffer into the contiguous one.
void memcpy_noncont2cont(void* cont, void* non_cont, size_t size) {
    void* const to = cont;
    const void* const from = non_cont;
    std::memcpy(to, from, size);
}
/*!
 * \brief Adapter giving relayout::transpose_fallback::transpose<T> the common
 *      kernel signature used by the transpose dispatch.
 *
 * Only a channel count of 1 is accepted; \p src and \p dst are reinterpreted
 * as arrays of \p T.
 */
template <typename T>
void call_transpose(
        size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst,
        size_t stride_m) {
    megdnn_assert(ch == 1);
    auto* typed_src = static_cast<T*>(src);
    auto* typed_dst = static_cast<T*>(dst);
    relayout::transpose_fallback::transpose<T>(
            batch, m, n, typed_src, typed_dst, stride_m);
}
/*!
 * \brief Build and dispatch a row-by-row copy kernel between a contiguous
 *      tensor and a tensor whose last dim is contiguous.
 *
 * \tparam ctype element type; only its size is used
 * \param handle handle the kernel is dispatched on (cast to naive::HandleImpl)
 * \param cont the contiguous operand
 * \param nonc the other operand; its layout (2 or 3 dims) drives the iteration
 * \param mcp_pol copy routine called once per innermost row; it decides the
 *      copy direction
 */
template <typename ctype>
void dispatch_on_dtype_cont(
        Handle* handle, const TensorND& cont, const TensorND& nonc,
        memcpy_policy_t mcp_pol) {
    thin_function<void()> kern;
    const auto& nc_layout = nonc.layout;

    if (nc_layout.ndim == 2) {
        const auto rows = nc_layout.shape[0];
        const auto row_len = nc_layout.shape[1];
        // byte distance between consecutive rows on each side
        const auto nonc_row_step = nc_layout.stride[0] * sizeof(ctype);
        const auto cont_row_step = row_len * sizeof(ctype);
        kern = [cont, nonc, mcp_pol, rows, nonc_row_step, cont_row_step]() {
            // raw pointers are queried when the kernel runs
            auto* cont_ptr = static_cast<uint8_t*>(cont.raw_ptr());
            auto* nonc_ptr = static_cast<uint8_t*>(nonc.raw_ptr());
            for (size_t r = 0; r < rows; ++r) {
                mcp_pol(cont_ptr, nonc_ptr, cont_row_step);
                cont_ptr += cont_row_step;
                nonc_ptr += nonc_row_step;
            }
        };
    } else if (nc_layout.ndim == 3) {
        const auto planes = nc_layout.shape[0];
        const auto rows = nc_layout.shape[1];
        const auto row_len = nc_layout.shape[2];
        const auto nonc_plane_step = nc_layout.stride[0] * sizeof(ctype);
        const auto nonc_row_step = nc_layout.stride[1] * sizeof(ctype);
        const auto cont_row_step = row_len * sizeof(ctype);
        kern = [cont, nonc, mcp_pol, planes, rows, nonc_plane_step, nonc_row_step,
                cont_row_step]() {
            auto* cont_ptr = static_cast<uint8_t*>(cont.raw_ptr());
            auto* nonc_plane = static_cast<uint8_t*>(nonc.raw_ptr());
            for (size_t p = 0; p < planes; ++p) {
                // the contiguous side advances monotonically; the other side
                // restarts from the beginning of each plane
                auto* nonc_ptr = nonc_plane;
                for (size_t r = 0; r < rows; ++r) {
                    mcp_pol(cont_ptr, nonc_ptr, cont_row_step);
                    cont_ptr += cont_row_step;
                    nonc_ptr += nonc_row_step;
                }
                nonc_plane += nonc_plane_step;
            }
        };
    } else {
        megdnn_assert(0);
    }

    static_cast<naive::HandleImpl*>(handle)->dispatch_kern(std::move(kern));
}
/*!
 * \brief Pick the dispatch_on_dtype_cont instantiation matching the dtype of
 *      \p cont and forward all arguments to it.
 *
 * \param handle handle the kernel is dispatched on
 * \param cont the contiguous operand; its dtype selects the element type
 * \param nonc the operand that is only contiguous on its last dim
 * \param mcp_pol copy routine that decides the copy direction
 *
 * Asserts if the dtype matches none of the enumerated cases.
 */
void dispatch_cont(
        Handle* handle, const TensorND& cont, const TensorND& nonc,
        memcpy_policy_t mcp_pol) {
    switch (cont.layout.dtype.enumv()) {
#define cb(_dt)                                                       \
    case DTypeTrait<dtype::_dt>::enumv:                               \
        return dispatch_on_dtype_cont<equiv_ctype<dtype::_dt>::type>( \
                handle, cont, nonc, mcp_pol);
        MEGDNN_FOREACH_DTYPE_NAME(cb)
        MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb)
#undef cb
    }
    // Must live after the switch: every case returns, so a statement placed
    // inside the switch body behind the last case can never execute and an
    // unmatched dtype would return silently without copying anything.
    megdnn_assert(0);
}
//! edge length (in pixels) of the square tile handled by transpose_cv_block
const size_t BLOCK_SIZE = 16;
//! upper bound, in bytes, of one pixel (all channels) for the transpose_cv path
const size_t TRANSPOSE_CV_MAX_C = relayout::transpose_fallback::BLOCK_LINE_SIZE_BYTES;
template <typename ctype>
void transpose_cv_block(
size_t m, size_t n, size_t ch, size_t i, size_t j, size_t h, size_t w,
void* src, void* dst) {
auto batch_src = static_cast<const ctype*>(src);
auto batch_dst = static_cast<ctype*>(dst);
#define SET_VAL(dst, src) \
switch (ch) { \
case 3: \
dst[2] = src[2]; \
MEGDNN_FALLTHRU \
case 2: \
dst[1] = src[1]; \
MEGDNN_FALLTHRU \
case 1: \
dst[0] = src[0]; \
break; \
default: \
for (size_t _c = 0; _c < ch; ++_c) \
dst[_c] = src[_c]; \
break; \
}
constexpr size_t B = BLOCK_SIZE;
static_assert(TRANSPOSE_CV_MAX_C % sizeof(ctype) == 0, "bad ctype");
ctype tmp[B][B][TRANSPOSE_CV_MAX_C / sizeof(ctype)];
auto sptr = batch_src + i * n * ch + j * ch;
for (size_t x = 0; x < h; ++x) {
for (size_t y = 0; y < w; ++y) {
SET_VAL(tmp[y][x], (sptr + y * ch))
}
sptr += n * ch;
}
auto dptr = batch_dst + j * m * ch + i * ch;
for (size_t x = 0; x < w; ++x) {
for (size_t y = 0; y < h; ++y) {
SET_VAL((dptr + y * ch), tmp[x][y])
}
dptr += m * ch;
}
#undef SET_VAL
}
/*!
 * \brief Transpose a horizontal strip of \p h source rows starting at row
 *      \p i by splitting it into tiles of at most BLOCK_SIZE columns.
 *
 * Every tile is handed to transpose_cv_block; the last tile is narrower when
 * \p n is not a multiple of BLOCK_SIZE.
 */
template <typename ctype>
void transpose_cv_row(
        size_t m, size_t n, size_t ch, size_t i, size_t h, void* src, void* dst) {
    constexpr size_t B = BLOCK_SIZE;
    for (size_t col = 0; col < n; col += B) {
        const size_t cols_left = n - col;
        const size_t tile_w = cols_left < B ? cols_left : B;
        transpose_cv_block<ctype>(m, n, ch, i, col, h, tile_w, src, dst);
    }
}
/*!
 * \brief Batched transpose of (m, n, ch) arrays into (n, m, ch) arrays.
 *
 * Each batch item is processed in strips of at most BLOCK_SIZE rows through
 * transpose_cv_row; consecutive items are m * n * ch elements apart in both
 * \p src and \p dst.
 *
 * \param stride_m must be 0 (asserted)
 */
template <typename ctype>
void transpose_cv(
        size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst,
        size_t stride_m) {
    megdnn_assert(stride_m == 0);
    constexpr size_t B = BLOCK_SIZE;
    const size_t item_elems = m * n * ch;
    auto* item_src = static_cast<ctype*>(src);
    auto* item_dst = static_cast<ctype*>(dst);
    for (size_t remaining = batch; remaining > 0; --remaining) {
        for (size_t row = 0; row < m; row += B) {
            const size_t rows_left = m - row;
            const size_t strip_h = rows_left < B ? rows_left : B;
            transpose_cv_row<ctype>(m, n, ch, row, strip_h, item_src, item_dst);
        }
        item_src += item_elems;
        item_dst += item_elems;
    }
}
}
/*!
 * \brief Entry point of the fallback relayout: copy \p src0 into \p dst0.
 *
 * Layouts are first validated and canonized. Layouts with a negative stride
 * and the 4-bit quantized dtypes are delegated to the naive implementation;
 * everything else goes to exec_after_preprocess, together with the transpose
 * description when relayout::is_transpose recognizes one.
 */
void RelayoutForwardImpl::exec(
        _megdnn_tensor_in src0, _megdnn_tensor_out dst0, Handle* src_handle) {
    check_cpu_handle(src_handle);
    TensorND src = src0;
    TensorND dst = dst0;
    check_layout_and_canonize(src.layout, dst.layout);

    auto any_negative_stride = [](const TensorLayout& layout) {
        for (size_t axis = 0; axis < layout.ndim; ++axis) {
            if (layout.stride[axis] < 0) {
                return true;
            }
        }
        return false;
    };

    // negative strides and 4-bit dtypes are not handled by the paths below
    if (any_negative_stride(src.layout) || any_negative_stride(dst.layout) ||
        src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
        src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
        NaiveRelayoutForwardImpl::do_exec(src, dst);
        return;
    }

    relayout::TransposeParam trans_param;
    const bool is_trans =
            relayout::is_transpose(src.layout, dst.layout, trans_param, true);
    exec_after_preprocess(src, dst, is_trans ? &trans_param : nullptr);
}
/*!
 * \brief Run the relayout on already canonized tensors.
 *
 * \param src source tensor
 * \param dst destination tensor
 * \param transpose transpose description, or nullptr if the relayout is not
 *      a transpose
 *
 * Strategy, in order:
 *  1. transpose kernel, when \p transpose is given and a kernel exists for
 *     the pixel size (dtype size * channel count);
 *  2. one memcpy when both layouts are contiguous;
 *  3. row-wise memcpy when one side is contiguous and the other one is
 *     contiguous on its last dim;
 *  4. the naive implementation.
 */
void RelayoutForwardImpl::exec_after_preprocess(
        const TensorND& src, const TensorND& dst, relayout::TransposeParam* transpose) {
    if (transpose) {
        // Whether a transpose kernel exists depends only on the pixel size,
        // so it is decided here rather than inside the kernel. This lets us
        // return right after dispatching; without it a successful transpose
        // would be followed by a second, redundant copy through the generic
        // paths below.
        const size_t dsize = src.layout.dtype.size() * transpose->c;
        const bool has_transpose_kern = dsize == 1 || dsize == 2 || dsize == 3 ||
                                        dsize == 4 || dsize == 12 ||
                                        dsize <= TRANSPOSE_CV_MAX_C;
        if (has_transpose_kern) {
            auto kernel = [tparam = *transpose, src, dst, dsize]() {
                auto t = tparam;
                void (*kptr)(size_t, size_t, size_t, size_t, void*, void*, size_t) =
                        nullptr;
                // raw pointers (and thus the alignment checks) are evaluated
                // when the kernel runs, not when it is dispatched
                auto src_addr = reinterpret_cast<uintptr_t>(src.raw_ptr()),
                     dst_addr = reinterpret_cast<uintptr_t>(dst.raw_ptr());
                if (dsize == 1) {
                    megdnn_assert(t.c == 1);
                    kptr = call_transpose<uint8_t>;
                } else if (dsize == 2) {
                    // a whole pixel is moved as one element, so the channel
                    // count seen by the kernel becomes 1
                    t.c = 1;
                    if (!((src_addr | dst_addr) & (alignof(uint16_t) - 1))) {
                        kptr = call_transpose<uint16_t>;
                    } else {
                        kptr = call_transpose<equiv_ctype_storage<2>>;
                        megdnn_log_error("unaligned addr in relayout");
                    }
                } else if (dsize == 3) {
                    t.c = 1;
                    kptr = call_transpose<equiv_ctype_storage<3>>;
                } else if (dsize == 4) {
                    t.c = 1;
                    if (!((src_addr | dst_addr) & (alignof(uint32_t) - 1))) {
                        kptr = call_transpose<uint32_t>;
                    } else {
                        kptr = call_transpose<equiv_ctype_storage<4>>;
                        megdnn_log_error("unaligned addr in relayout");
                    }
                } else if (dsize == 12) {
                    t.c = 1;
                    if (!((src_addr | dst_addr) & (alignof(uint32_t) - 1))) {
                        kptr = call_transpose<equiv_ctype_storage<3, uint32_t>>;
                    } else {
                        kptr = call_transpose<equiv_ctype_storage<12>>;
                        megdnn_log_error("unaligned addr in relayout");
                    }
                } else if (dsize <= TRANSPOSE_CV_MAX_C) {
                    switch (dst.layout.dtype.enumv()) {
#define cb(_dt)                                             \
    case DTypeTrait<dtype::_dt>::enumv:                     \
        kptr = transpose_cv<equiv_ctype<dtype::_dt>::type>; \
        break;
                        MEGDNN_FOREACH_DTYPE_NAME(cb)
                        MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb)
#undef cb
                    }
                }
                megdnn_assert(kptr);
                auto sptr = src.raw_ptr();
                auto dptr = dst.raw_ptr();
                kptr(t.batch, t.m, t.n, t.c, sptr, dptr, t.stride_m);
            };
            MEGDNN_DISPATCH_CPU_KERN_OPR(kernel());
            return;
        }
        // no transpose kernel for this pixel size: only acceptable for
        // multi-channel transposes, which fall through to the generic paths
        megdnn_assert(transpose->c != 1, "unsupported dtype size");
    }

    using relayout::is_contig;
    if (is_contig(dst.layout) && is_contig(src.layout)) {
        auto sz = src.layout.span().dist_byte();
        MEGDNN_DISPATCH_CPU_KERN_OPR(memcpy(dst.raw_ptr(), src.raw_ptr(), sz));
        return;
    }
    if (is_contig(dst.layout) && is_lastdim_contig(src.layout)) {
        return dispatch_cont(handle(), dst, src, memcpy_noncont2cont);
    }
    if (is_contig(src.layout) && is_lastdim_contig(dst.layout)) {
        return dispatch_cont(handle(), src, dst, memcpy_cont2noncont);
    }
    NaiveRelayoutForwardImpl::do_exec(src, dst);
}