#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/armv7/matrix_mul/asm/common.h"
#include "src/armv7/matrix_mul/fp32/strategy.h"
#include "src/common/utils.h"
using namespace megdnn;
using namespace armv7;
using namespace armv7::matmul;
namespace {
// Computes one 4x12 output tile of the MK4 fp32 GEMM.
//
// packA      : packed A panel; every k step contributes 4 consecutive floats
//              (held in q0 / q1).
// packB      : packed B panel; every k step contributes 12 consecutive floats
//              (streamed through d4-d7 as scalars).
// K          : number of k steps to accumulate. NOTE(review): assumed >= 1 --
//              with K == 0 the loop counter below starts at -1 and the loop
//              does not terminate sensibly.
// output     : destination tile; 12 groups of 4 floats (q4..q15) are read and
//              written contiguously, which is why LDC is unused.
// is_first_k : true  -> accumulators start at zero;
//              false -> the 48 floats already at `output` are loaded and
//              accumulated into.
//
// The main loop (label 3) consumes two k steps per iteration. The tail after
// label 4 handles the final two steps (even K) or the final one (odd K,
// label 5) and stores the tile.
void kern_4x12(
        const float* packA, const float* packB, int K, float* output, int LDC,
        bool is_first_k) {
    MEGDNN_MARK_USED_VAR(LDC);
    const float* a_ptr = packA;
    const float* b_ptr = packB;
    float* output0 = output;
    int oddk = (K & 1);
    // Number of two-step loop iterations; the last one or two k steps are
    // peeled off into the tail after label 4.
    K = ((K + 1) / 2) - 1;
    asm volatile(
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            // Accumulate mode: load the existing 4x12 tile into q4-q15.
            "mov r1, %[output0]\n"
            "vld1.32 {d8-d11}, [r1]!\n"
            "vld1.32 {d12-d15}, [r1]!\n"
            "vld1.32 {d16-d19}, [r1]!\n"
            "vld1.32 {d20-d23}, [r1]!\n"
            "vld1.32 {d24-d27}, [r1]!\n"
            "vld1.32 {d28-d31}, [r1]!\n"
            // Preload A for the first step and the first 8 B scalars. The loop
            // and both tails expect B[0..7] of the current step in d4-d7, so
            // this must match the zero-init path below (loading only d4-d5
            // would leave d6-d7 stale and misalign b_ptr by 4 floats).
            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "vld1.32 {d4-d7}, [%[b_ptr]]!\n"
            "b 2f\n"
            // First-k mode: zero q4-q15 and preload A and B[0..7].
            "1:\n"
            "veor.32 q4, q4, q4\n"
            "pld [%[output0]]\n"
            "veor.32 q5, q4, q4\n"
            "veor.32 q6, q4, q4\n"
            "veor.32 q7, q4, q4\n"
            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "veor.32 q8, q4, q4\n"
            "veor.32 q9, q4, q4\n"
            "veor.32 q10, q4, q4\n"
            "veor.32 q11, q4, q4\n"
            "vld1.32 {d4-d7}, [%[b_ptr]]!\n"
            "veor.32 q12, q4, q4\n"
            "veor.32 q13, q4, q4\n"
            "veor.32 q14, q4, q4\n"
            "veor.32 q15, q4, q4\n"
            "2: \n"
            "cmp %[K], #0\n"
            "beq 4f\n"
            // Main loop: step k uses q0 with B scalars in d4-d7 then d4-d5;
            // step k+1 uses q1. A and B for the next iteration are prefetched
            // so that on re-entry q0 = A[k+2] and d4-d7 = B[k+2][0..7].
            "3:\n"
            "vmla.f32 q4, q0, d4[0]\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q0, d6[0]\n"
            "vmla.f32 q9, q0, d6[1]\n"
            "vmla.f32 q10, q0, d7[0]\n"
            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
            "vmla.f32 q11, q0, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q12, q0, d4[0]\n"
            "vmla.f32 q13, q0, d4[1]\n"
            "vmla.f32 q14, q0, d5[0]\n"
            "vmla.f32 q15, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q4, q1, d6[0]\n"
            "subs %[K], %[K], #1\n"
            "vmla.f32 q5, q1, d6[1]\n"
            "vmla.f32 q6, q1, d7[0]\n"
            "vmla.f32 q7, q1, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q1, d4[0]\n"
            "vmla.f32 q9, q1, d4[1]\n"
            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "vmla.f32 q10, q1, d5[0]\n"
            "vmla.f32 q11, q1, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q12, q1, d6[0]\n"
            "vmla.f32 q13, q1, d6[1]\n"
            "vmla.f32 q14, q1, d7[0]\n"
            "vmla.f32 q15, q1, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "bne 3b\n"
            // Tail: two remaining steps when K is even (falls through), one
            // step when K is odd (label 5). Stores are interleaved with the
            // final multiply-accumulates.
            "4:\n"
            "cmp %[oddk], #1\n"
            "beq 5f\n"
            "vmla.f32 q4, q0, d4[0]\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q0, d6[0]\n"
            "vmla.f32 q9, q0, d6[1]\n"
            "vmla.f32 q10, q0, d7[0]\n"
            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
            "vmla.f32 q11, q0, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q12, q0, d4[0]\n"
            "vmla.f32 q13, q0, d4[1]\n"
            "vmla.f32 q14, q0, d5[0]\n"
            "vmla.f32 q15, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q4, q1, d6[0]\n"
            "subs %[K], %[K], #1\n"
            "vmla.f32 q5, q1, d6[1]\n"
            "vmla.f32 q6, q1, d7[0]\n"
            "vmla.f32 q7, q1, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q1, d4[0]\n"
            "vmla.f32 q9, q1, d4[1]\n"
            "vst1.32 {d8-d11}, [%[output0]]!\n"
            "vmla.f32 q10, q1, d5[0]\n"
            "vmla.f32 q11, q1, d5[1]\n"
            "vst1.32 {d12-d15}, [%[output0]]!\n"
            "vmla.f32 q12, q1, d6[0]\n"
            "vmla.f32 q13, q1, d6[1]\n"
            "vst1.32 {d16-d19}, [%[output0]]!\n"
            "vmla.f32 q14, q1, d7[0]\n"
            "vmla.f32 q15, q1, d7[1]\n"
            "vst1.32 {d20-d23}, [%[output0]]!\n"
            "vst1.32 {d24-d27}, [%[output0]]!\n"
            "vst1.32 {d28-d31}, [%[output0]]!\n"
            "b 6f\n"
            "5:\n"
            "vmla.f32 q4, q0, d4[0]\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q0, d6[0]\n"
            "vst1.32 {d8-d11}, [%[output0]]!\n"
            "vmla.f32 q9, q0, d6[1]\n"
            "vmla.f32 q10, q0, d7[0]\n"
            "vst1.32 {d12-d15}, [%[output0]]!\n"
            "vmla.f32 q11, q0, d7[1]\n"
            "vmla.f32 q12, q0, d4[0]\n"
            "vst1.32 {d16-d19}, [%[output0]]!\n"
            "vmla.f32 q13, q0, d4[1]\n"
            "vst1.32 {d20-d23}, [%[output0]]!\n"
            "vmla.f32 q14, q0, d5[0]\n"
            "vst1.32 {d24-d27}, [%[output0]]!\n"
            "vmla.f32 q15, q0, d5[1]\n"
            "vst1.32 {d28-d31}, [%[output0]]!\n"
            "6:\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
              [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [output0] "+r"(output0)
            :
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
              "q12", "q13", "q14", "q15", "r1", "cc", "memory");
}
// Computes one 4-wide output tile (used for the N columns left over after the
// 4x12 tiles).
//
// packA      : packed A panel; each k step contributes 4 consecutive floats
//              (q0 / q1).
// packB      : packed B panel; each k step contributes 4 consecutive floats
//              (d4-d5 / d6-d7).
// K          : number of k steps. NOTE(review): assumed >= 1 -- with K == 0 the
//              loop counter below starts at -1.
// output     : destination; n_remain groups of 4 floats are read/written
//              contiguously, so LDC is unused.
// is_first_k : true  -> accumulators q4-q7 start at zero;
//              false -> the existing values at `output` are loaded first
//              (LOAD_C).
// n_remain   : number of valid output columns:
//              >= 4 -> 4 groups, 3 -> 3, 2 -> 2, anything below 2 -> 1.
void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
// Iterations of the two-step main loop. The final one (odd K) or two
// (even K) steps are handled after label 4.
K = ((K + 1) / 2) - 1;
// LOAD_C: load the first n_remain columns (4 floats each) of the existing tile
// from r1 into q4.. (d8-d15). Accumulators that are not loaded are also not
// stored by STORE_C, which uses the same n_remain thresholds.
#define LOAD_C \
"cmp %[n_remain], #4\n" \
"blt 11f\n" \
"vld1.32 {d8-d11}, [r1]!\n" \
"vld1.32 {d12-d15}, [r1]!\n" \
"b 14f\n" \
"11:\n" \
"cmp %[n_remain], #3\n" \
"blt 12f\n" \
"vld1.32 {d8-d11}, [r1]!\n" \
"vld1.32 {d12-d13}, [r1]!\n" \
"b 14f\n" \
"12:\n" \
"cmp %[n_remain], #2\n" \
"blt 13f\n" \
"vld1.32 {d8-d11}, [r1]\n" \
"b 14f\n" \
"13:\n" \
"vld1.32 {d8-d9}, [r1]\n" \
"14:\n"
// STORE_C: write the first n_remain accumulators (q4..q7) to `output`,
// post-incrementing the output pointer operand.
#define STORE_C \
"cmp %[n_remain], #4\n" \
"blt 21f\n" \
"vst1.32 {d8-d11}, [%[output]]!\n" \
"vst1.32 {d12-d15}, [%[output]]!\n" \
"b 24f\n" \
"21:\n" \
"cmp %[n_remain], #3\n" \
"blt 22f\n" \
"vst1.32 {d8-d11}, [%[output]]!\n" \
"vst1.32 {d12-d13}, [%[output]]!\n" \
"b 24f\n" \
"22:\n" \
"cmp %[n_remain], #2\n" \
"blt 23f\n" \
"vst1.32 {d8-d11}, [%[output]]!\n" \
"b 24f\n" \
"23:\n" \
"vst1.32 {d8-d9}, [%[output]]!\n" \
"24:\n"
asm volatile(
"cmp %[is_first_k], #1\n"
"beq 1f\n"
// Accumulate mode: r1 walks the existing tile while LOAD_C reads it, then A
// and B for the first step are preloaded.
"mov r1, %[output]\n" LOAD_C
"vld1.32 {d0-d1}, [%[a_ptr]]!\n"
"vld1.32 {d4-d5}, [%[b_ptr]]!\n"
"b 2f\n"
// First-k mode: zero q4-q7 and preload A and B for the first step.
"1:\n"
"veor.32 q4, q4, q4\n"
"pld [%[output]]\n"
"veor.32 q5, q4, q4\n"
"vld1.32 {d0-d1}, [%[a_ptr]]!\n"
"veor.32 q6, q4, q4\n"
"vld1.32 {d4-d5}, [%[b_ptr]]!\n"
"veor.32 q7, q4, q4\n"
"2: \n"
"cmp %[K], #0\n"
"beq 4f\n"
// Main loop: step k uses q0 x d4-d5, step k+1 uses q1 x d6-d7. A and B for
// the next iteration are reloaded into q0 / d4-d5. Only the subs sets flags
// before bne.
"3:\n"
"vmla.f32 q4, q0, d4[0]\n"
"vld1.32 {d2-d3}, [%[a_ptr]]!\n"
"vmla.f32 q5, q0, d4[1]\n"
"vld1.32 {d6-d7}, [%[b_ptr]]!\n"
"vmla.f32 q6, q0, d5[0]\n"
"vmla.f32 q7, q0, d5[1]\n"
"vld1.32 {d4-d5}, [%[b_ptr]]!\n"
"vmla.f32 q4, q1, d6[0]\n"
"subs %[K], %[K], #1\n"
"vmla.f32 q5, q1, d6[1]\n"
"vld1.32 {d0-d1}, [%[a_ptr]]!\n"
"vmla.f32 q6, q1, d7[0]\n"
"vmla.f32 q7, q1, d7[1]\n"
"bne 3b\n"
// Tail: two remaining steps when K is even (falls through), one step when
// K is odd (label 5). Both paths end at label 6, where STORE_C runs.
"4:\n"
"cmp %[oddk], #1\n"
"beq 5f\n"
"vmla.f32 q4, q0, d4[0]\n"
"vld1.32 {d2-d3}, [%[a_ptr]]!\n"
"vmla.f32 q5, q0, d4[1]\n"
"vld1.32 {d6-d7}, [%[b_ptr]]!\n"
"vmla.f32 q6, q0, d5[0]\n"
"vmla.f32 q7, q0, d5[1]\n"
"vmla.f32 q4, q1, d6[0]\n"
"vmla.f32 q5, q1, d6[1]\n"
"vmla.f32 q6, q1, d7[0]\n"
"vmla.f32 q7, q1, d7[1]\n"
"b 6f\n"
"5:\n"
"vmla.f32 q4, q0, d4[0]\n"
"vmla.f32 q5, q0, d4[1]\n"
"vmla.f32 q6, q0, d5[0]\n"
"vmla.f32 q7, q0, d5[1]\n"
"6:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [output] "+r"(output),
[n_remain] "+r"(n_remain)
:
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "r1", "cc", "memory");
#undef LOAD_C
#undef STORE_C
}
}
// Project macro invocation for the sgemm_mk4_pack_4x12 strategy. Its expansion
// is defined outside this file. NOTE(review): presumably it supplies the
// strategy's common boilerplate (e.g. constructor) -- confirm against the
// macro definition.
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_pack_4x12);
// Packs rows [y0, ymax) and k-range [k0, kmax) of an MK4-layout A matrix.
//
// Each group of 4 rows lives in its own `ldin`-strided slab of `in`. Within a
// slab, the k-range [k0, kmax) is (kmax - k0) * 4 contiguous floats, which are
// copied verbatim and appended to `out`. The trailing bool is ignored.
void sgemm_mk4_pack_4x12::pack_A(
        float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax,
        bool) const {
    megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4");
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int PACK_C_SIZE = 4;
    const size_t slab_floats = (kmax - k0) * PACK_C_SIZE;
    const size_t slab_bytes = slab_floats * sizeof(float);
    float* dst = out;
    int row = y0;
    while (row < ymax) {
        const int group = row / PACK_C_SIZE;
        const float* slab = in + group * ldin + k0 * PACK_C_SIZE;
        memcpy(dst, slab, slab_bytes);
        dst += slab_floats;
        row += 4;
    }
}
// Packs B (MK4 layout, non-transposed only) for columns [x0, xmax) and k-range
// [k0, kmax).
//
// Output layout:
//   * first, (xmax - x0) / 12 panels of 12 columns, each (kmax - k0) * 12 floats;
//   * then, panels of 4 columns, each (kmax - k0) * 4 floats;
//   * a final group of fewer than 4 columns is zero-padded to a full 4-column
//     panel through a staging buffer.
//
// Each group of 4 k values is read from one `ldin`-strided row of `in`. The
// transpose helpers are passed `src` directly and a copy of the output cursor;
// presumably they advance `src` past what they consume (successive calls rely
// on it) -- confirm against their definitions.
void sgemm_mk4_pack_4x12::pack_B(
        float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax,
        bool transpose_B) const {
    megdnn_assert(!transpose_B);
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int PACK_C_SIZE = 4;
    // Zero-initialised once. Every k group copies the same (xmax - col) * 4
    // floats into it, so the padding tail stays zero across iterations.
    float tail_buf[16] = {0.0f};
    const int depth = kmax - k0;
    const int stride12 = depth * 12;
    const int stride4 = depth * 4;
    float* dst12_base = out;
    float* dst4_base = out + (xmax - x0) / 12 * stride12;

    for (int k = k0; k + 3 < kmax; k += 4) {
        const float* src = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE;
        prefetch_3x(src);

        int col = x0;
        for (float* panel12 = dst12_base; col + 12 <= xmax;
             col += 12, panel12 += stride12) {
            float* cursor = panel12;
            transpose_1x12_4_s(src, cursor);
        }

        float* panel4 = dst4_base;
        for (; col + 4 <= xmax; col += 4, panel4 += stride4) {
            float* cursor = panel4;
            transpose_1x4_4_s(src, cursor);
        }

        if (col < xmax) {
            memcpy(tail_buf, src, sizeof(float) * (xmax - col) * PACK_C_SIZE);
            const float* tail_src = &tail_buf[0];
            float* cursor = panel4;
            transpose_1x4_4_s<float>(tail_src, cursor);
        }

        // The next k group lands 4 k-rows further into every panel.
        dst12_base += 12 * PACK_C_SIZE;
        dst4_base += 4 * PACK_C_SIZE;
    }
}
// Drives the micro-kernels over the packed A and B panels.
//
// For every group of 4 rows (one packed A panel of K * 4 floats), the N columns
// are covered first by full 4x12 tiles, then by 4-wide tiles. The last 4-wide
// tile may be partial; its width is passed as n_remain. Output for row group
// `row` starts at C + row / 4 * LDC, and tiles are written back-to-back.
void sgemm_mk4_pack_4x12::kern(
        const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
        size_t LDC, bool is_first_k, const float*, float*) const {
    megdnn_assert(
            A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() &&
            A_dtype.enumv() == DTypeEnum::Float32);
    constexpr int PACK_C_SIZE = 4;
    constexpr size_t A_INTERLEAVE = 4;
    constexpr size_t B_INTERLEAVE = 12;
    // Packed-panel sizes in floats: a 12-wide B panel, and a 4-wide A/B panel.
    const int panel12_stride = K * 12;
    const int panel4_stride = K * 4;

    const float* a_panel = packA;
    for (size_t row = 0; row < M; row += A_INTERLEAVE, a_panel += panel4_stride) {
        float* dst = C + row / 4 * LDC;
        const float* b_panel = packB;
        size_t col = 0;

        while (col + B_INTERLEAVE <= N) {
            kern_4x12(a_panel, b_panel, K, dst, LDC, is_first_k);
            dst += PACK_C_SIZE * B_INTERLEAVE;
            b_panel += panel12_stride;
            col += B_INTERLEAVE;
        }

        while (col < N) {
            const size_t remain = std::min<size_t>(N - col, 4);
            kern_4x4(a_panel, b_panel, K, dst, LDC, is_first_k, remain);
            dst += PACK_C_SIZE * 4;
            b_panel += panel4_stride;
            col += 4;
        }
    }
}