#include "src/common/utils.h"
#include <cstring>
#include "src/arm_common/simd_macro/marm_neon.h"
using namespace megdnn;
namespace {
/*!
 * \brief Scalar fallback that writes the transpose of a strided tile.
 *
 * For every i in [0, n) and j in [0, m) this stores
 * dst[i * ldb + j] = src[j * lda + i].
 *
 * \param src first element of the source tile
 * \param dst first element of the destination tile
 * \param lda distance, in elements, between consecutive source rows
 * \param ldb distance, in elements, between consecutive destination rows
 * \param n   number of destination rows (source columns) to produce
 * \param m   number of destination columns (source rows) to produce
 *
 * The strides are ptrdiff_t rather than int: callers forward ptrdiff_t leading
 * dimensions, and keeping the offset arithmetic in 32-bit int would truncate
 * large strides and could overflow in i * ldb / j * lda.
 */
template <typename dtype>
void transpose_naive(
        const dtype* src, dtype* dst, ptrdiff_t lda, ptrdiff_t ldb, int n, int m) {
    for (int i = 0; i < n; ++i) {
        const dtype* src_col = src + i;
        dtype* dst_row = dst + i * ldb;
        for (int j = 0; j < m; ++j) {
            // Promoted to ptrdiff_t by lda, so the offset is pointer-width.
            dst_row[j] = src_col[j * lda];
        }
    }
}
void transpose_4x4_neon(const float* src, float* dst, int lda, int ldb) {
float32x4x2_t a0, a1;
a0.val[0] = vld1q_f32(src + 0 * lda);
a0.val[1] = vld1q_f32(src + 1 * lda);
a1.val[0] = vld1q_f32(src + 2 * lda);
a1.val[1] = vld1q_f32(src + 3 * lda);
float32x4x2_t b0 = vzipq_f32(a0.val[0], a1.val[0]);
float32x4x2_t b1 = vzipq_f32(a0.val[1], a1.val[1]);
float32x4x2_t c0 = vzipq_f32(b0.val[0], b1.val[0]);
float32x4x2_t c1 = vzipq_f32(b0.val[1], b1.val[1]);
vst1q_f32(dst + 0 * ldb, c0.val[0]);
vst1q_f32(dst + 1 * ldb, c0.val[1]);
vst1q_f32(dst + 2 * ldb, c1.val[0]);
vst1q_f32(dst + 3 * ldb, c1.val[1]);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void transpose_8x8_neon(const dt_float16* src, dt_float16* dst, int lda, int ldb) {
const __fp16* src_ptr = reinterpret_cast<const __fp16*>(src);
__fp16* dst_ptr = reinterpret_cast<__fp16*>(dst);
float16x8x4_t a0, a1;
a0.val[0] = vld1q_f16(src_ptr + 0 * lda); a0.val[1] = vld1q_f16(src_ptr + 1 * lda); a0.val[2] = vld1q_f16(src_ptr + 2 * lda); a0.val[3] = vld1q_f16(src_ptr + 3 * lda); a1.val[0] = vld1q_f16(src_ptr + 4 * lda); a1.val[1] = vld1q_f16(src_ptr + 5 * lda); a1.val[2] = vld1q_f16(src_ptr + 6 * lda); a1.val[3] = vld1q_f16(src_ptr + 7 * lda);
float16x8x2_t b0 =
vzipq_f16(a0.val[0], a1.val[0]); float16x8x2_t b1 =
vzipq_f16(a0.val[2], a1.val[2]); float16x8x2_t c0 =
vzipq_f16(a0.val[1], a1.val[1]); float16x8x2_t c1 =
vzipq_f16(a0.val[3], a1.val[3]);
float16x8x2_t d0 =
vzipq_f16(b0.val[0], b1.val[0]); float16x8x2_t d1 =
vzipq_f16(c0.val[0], c1.val[0]); float16x8x2_t e0 =
vzipq_f16(d0.val[0], d1.val[0]); float16x8x2_t e1 =
vzipq_f16(d0.val[1], d1.val[1]);
float16x8x2_t f0 =
vzipq_f16(b0.val[1], b1.val[1]); float16x8x2_t f1 =
vzipq_f16(c0.val[1], c1.val[1]); float16x8x2_t g0 =
vzipq_f16(f0.val[0], f1.val[0]); float16x8x2_t g1 =
vzipq_f16(f0.val[1], f1.val[1]);
vst1q_f16(dst_ptr + 0 * ldb, e0.val[0]);
vst1q_f16(dst_ptr + 1 * ldb, e0.val[1]);
vst1q_f16(dst_ptr + 2 * ldb, e1.val[0]);
vst1q_f16(dst_ptr + 3 * ldb, e1.val[1]);
vst1q_f16(dst_ptr + 4 * ldb, g0.val[0]);
vst1q_f16(dst_ptr + 5 * ldb, g0.val[1]);
vst1q_f16(dst_ptr + 6 * ldb, g1.val[0]);
vst1q_f16(dst_ptr + 7 * ldb, g1.val[1]);
}
#endif
}
namespace megdnn {
/*!
 * \brief float specialization of transpose: dst[i * ldd + j] = src[j * lds + i]
 *        for i in [0, n) and j in [0, m).
 *
 * \param src source matrix, m rows of n elements
 * \param dst destination matrix, n rows of m elements
 * \param m   number of source rows
 * \param n   number of source columns
 * \param lds source row stride in elements; -1 selects n
 * \param ldd destination row stride in elements; -1 selects m
 *
 * The matrix is walked in 16x16 tiles (presumably for cache locality). Inside a
 * tile, full 4x4 blocks go through transpose_4x4_neon and whatever does not
 * fill a block is handled by transpose_naive.
 */
template <>
void transpose(
        const float* src, float* dst, size_t m, size_t n, ptrdiff_t lds,
        ptrdiff_t ldd) {
    // Resolve the "-1 means densely packed" stride sentinels.
    lds = (lds == -1) ? static_cast<ptrdiff_t>(n) : lds;
    ldd = (ldd == -1) ? static_cast<ptrdiff_t>(m) : ldd;

    constexpr size_t kTile = 16;  // edge length of one tile
    constexpr size_t kBlock = 4;  // edge length handled by the NEON kernel

    for (size_t col_begin = 0; col_begin < n; col_begin += kTile) {
        const size_t col_end = std::min(col_begin + kTile, n);
        for (size_t row_begin = 0; row_begin < m; row_begin += kTile) {
            const size_t row_end = std::min(row_begin + kTile, m);

            // Strips of kBlock source columns that fit entirely in the tile.
            size_t col = col_begin;
            while (col + kBlock <= col_end) {
                size_t row = row_begin;
                while (row + kBlock <= row_end) {
                    transpose_4x4_neon(
                            src + row * lds + col, dst + col * ldd + row, lds, ldd);
                    row += kBlock;
                }
                // Source rows left over at the bottom of this strip.
                if (row < row_end) {
                    transpose_naive(
                            src + row * lds + col, dst + col * ldd + row, lds, ldd,
                            kBlock, row_end - row);
                }
                col += kBlock;
            }

            // Source columns left over at the right edge of the tile.
            if (col < col_end) {
                transpose_naive(
                        src + row_begin * lds + col, dst + col * ldd + row_begin,
                        lds, ldd, col_end - col, row_end - row_begin);
            }
        }
    }
}
/*!
 * \brief Shared implementation behind the transpose_knc2nsck specializations.
 *
 * When n_stride equals k * c, a single transpose of the k x (n * c) source is
 * issued. Otherwise one k x c transpose is issued per index in [0, n): it
 * reads from src + i * c with source row stride n * c and writes to
 * dst + i * n_stride with the default destination stride.
 *
 * \param src      source buffer
 * \param dst      destination buffer
 * \param k        number of source rows
 * \param n        number of slices
 * \param c        number of elements per slice in a source row
 * \param n_stride distance, in elements, between consecutive slices in dst
 */
template <typename dtype>
void transpose_knc2nsck_helper(
        const dtype* src, dtype* dst, size_t k, size_t n, size_t c, size_t n_stride) {
    const bool slices_are_adjacent = (n_stride == k * c);
    if (slices_are_adjacent) {
        // Slices follow each other without gaps, so one big transpose suffices.
        transpose(src, dst, k, n * c);
        return;
    }

    // Slices are spaced out in dst: transpose them one at a time.
    for (size_t slice = 0; slice < n; ++slice) {
        const dtype* slice_src = src + slice * c;
        dtype* slice_dst = dst + slice * n_stride;
        transpose(slice_src, slice_dst, k, c, n * c);
    }
}
//! float specialization of transpose_knc2nsck; forwards every argument
//! unchanged to transpose_knc2nsck_helper.
template <>
void transpose_knc2nsck(
        const float* src, float* dst, size_t k, size_t n, size_t c,
        size_t n_stride) {
    transpose_knc2nsck_helper<float>(src, dst, k, n, c, n_stride);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/*!
 * \brief dt_float16 specialization of transpose:
 *        dst[i * ldd + j] = src[j * lds + i] for i in [0, n) and j in [0, m).
 *
 * \param src source matrix, m rows of n elements
 * \param dst destination matrix, n rows of m elements
 * \param m   number of source rows
 * \param n   number of source columns
 * \param lds source row stride in elements; -1 selects n
 * \param ldd destination row stride in elements; -1 selects m
 *
 * The matrix is walked in 16x16 tiles (presumably for cache locality). Inside a
 * tile, full 8x8 blocks go through transpose_8x8_neon and whatever does not
 * fill a block is handled by transpose_naive.
 */
template <>
void transpose(
        const dt_float16* src, dt_float16* dst, size_t m, size_t n, ptrdiff_t lds,
        ptrdiff_t ldd) {
    // Resolve the "-1 means densely packed" stride sentinels.
    lds = (lds == -1) ? static_cast<ptrdiff_t>(n) : lds;
    ldd = (ldd == -1) ? static_cast<ptrdiff_t>(m) : ldd;

    constexpr size_t kTile = 16;  // edge length of one tile
    constexpr size_t kBlock = 8;  // edge length handled by the NEON kernel

    for (size_t col_begin = 0; col_begin < n; col_begin += kTile) {
        const size_t col_end = std::min(col_begin + kTile, n);
        for (size_t row_begin = 0; row_begin < m; row_begin += kTile) {
            const size_t row_end = std::min(row_begin + kTile, m);

            // Strips of kBlock source columns that fit entirely in the tile.
            size_t col = col_begin;
            while (col + kBlock <= col_end) {
                size_t row = row_begin;
                while (row + kBlock <= row_end) {
                    transpose_8x8_neon(
                            src + row * lds + col, dst + col * ldd + row, lds, ldd);
                    row += kBlock;
                }
                // Source rows left over at the bottom of this strip.
                if (row < row_end) {
                    transpose_naive(
                            src + row * lds + col, dst + col * ldd + row, lds, ldd,
                            kBlock, row_end - row);
                }
                col += kBlock;
            }

            // Source columns left over at the right edge of the tile.
            if (col < col_end) {
                transpose_naive(
                        src + row_begin * lds + col, dst + col * ldd + row_begin,
                        lds, ldd, col_end - col, row_end - row_begin);
            }
        }
    }
}
//! dt_float16 specialization of transpose_knc2nsck; forwards every argument
//! unchanged to transpose_knc2nsck_helper.
template <>
void transpose_knc2nsck(
        const dt_float16* src, dt_float16* dst, size_t k, size_t n, size_t c,
        size_t n_stride) {
    transpose_knc2nsck_helper<dt_float16>(src, dst, k, n, c, n_stride);
}
#endif
}