#if MGB_ENABLE_DOT
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
namespace megdnn {
namespace aarch64 {
namespace matmul_8x8x4 {
// SUB_LANE(v0, v1, lane1, v2, vtmp) expands to assembly text that performs, on
// 4 x 32-bit lanes:
//     v<vtmp> = broadcast(v<v1>.s[lane1])
//     v<v0>  -= v<vtmp>
//     v<v0>  -= v<v2>
// All arguments are bare register numbers / lane indices that get pasted into
// the instruction strings; v<vtmp> is overwritten as scratch.
#define SUB_LANE(v0, v1, lane1, v2, vtmp) \
"dup v" #vtmp ".4s, v" #v1 ".s[" #lane1 \
"]\n" \
"sub v" #v0 ".4s, v" #v0 ".4s, v" #vtmp \
".4s\n" \
"sub v" #v0 ".4s, v" #v0 ".4s, v" #v2 ".4s\n"
/**
 * \brief Compute one 8x8 tile of int32 output from packed uint8 operands using
 *        the UDOT instruction, including the zero-point correction terms.
 *
 * For output element (r, c) the kernel produces
 *     C[r][c] (+)= sum(a*b) - zero_point_B * sum_k(a[r]) - zero_point_A * sum_k(b[c]) + zAB
 * where the raw product is accumulated in v6-v21, the A row sums (already
 * scaled by zero_point_B) in v27/v28 and the B column sums (already scaled by
 * zero_point_A) in v29/v30.
 *
 * \param packA  packed A panel; every 16-byte load holds 4 rows x 4 k-values
 * \param packB  packed B panel; every 16-byte load holds 4 columns x 4 k-values
 * \param K      depth; divided by 4 here. Presumably a positive multiple of 4
 *               (a step count of zero would enter the main loop with k == 0)
 *               -- confirm with caller.
 * \param output top-left element of the 8x8 destination tile
 * \param LDC    destination row stride in elements (converted to bytes below)
 * \param is_first_k  true: start from zero; false: accumulate onto \p output
 * \param zero_point_A, zero_point_B  quantization zero points
 * \param zAB    constant added to every output element; presumably
 *               zero_point_A * zero_point_B * K -- confirm with caller.
 */
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_8x8(
        const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC,
        bool is_first_k, uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) {
    K /= 4;
    const uint8_t* a_ptr = packA;
    const uint8_t* b_ptr = packB;
    // One odd 4-deep step is peeled off first; the main loop then consumes two
    // steps per iteration.
    int oddk = (K & 1);
    int k = K / 2;
    // Scratch vector registers chosen by the compiler; always written by a
    // load inside the asm before being read.
    int32x4_t a0;
    int32x4_t a1;
    int32x4_t b0;
    int32x4_t b1;
    int32x4_t a0a;
    int32x4_t a1a;
    // Row stride in bytes from here on.
    // NOTE(review): LDC is a 32-bit int but is read through %x[LDC] below; this
    // relies on the upper half of the register being clean -- confirm.
    LDC = LDC * sizeof(int32_t);
    int32_t* outptr0 = output;
    int32_t* outptr1;
    int32_t* outptr2;
    int32_t* outptr3;
    int32_t* outptr4;
    int32_t* outptr5;
    int32_t* outptr6;
    int32_t* outptr7;
    asm volatile(
            // Row pointers of the destination tile.
            "add %[outptr1], %[outptr0], %x[LDC]\n"
            "add %[outptr2], %[outptr1], %x[LDC]\n"
            "add %[outptr3], %[outptr2], %x[LDC]\n"
            "add %[outptr4], %[outptr3], %x[LDC]\n"
            "add %[outptr5], %[outptr4], %x[LDC]\n"
            "add %[outptr6], %[outptr5], %x[LDC]\n"
            "add %[outptr7], %[outptr6], %x[LDC]\n"
            // Broadcast the quantization constants.
            "dup v24.16b, %w[zero_point_B] \n"
            "dup v25.16b, %w[zero_point_A] \n"
            "dup v26.4s, %w[zAB] \n"
            "cmp %w[is_first_k], #1\n"
            "beq 1f\n"
            // Not the first K block: continue from the stored partial result.
            "ldp q6, q14, [%[outptr0]]\n"
            "ldp q7, q15, [%[outptr1]]\n"
            "ldp q8, q16, [%[outptr2]]\n"
            "ldp q9, q17, [%[outptr3]]\n"
            "ldp q10, q18, [%[outptr4]]\n"
            "ldp q11, q19, [%[outptr5]]\n"
            "ldp q12, q20, [%[outptr6]]\n"
            "ldp q13, q21, [%[outptr7]]\n"
            "b 2f\n"
            "1:\n"
            // First K block: start the product accumulators from zero.
            "eor v6.16b, v6.16b, v6.16b\n"
            "eor v7.16b, v7.16b, v7.16b\n"
            "eor v8.16b, v8.16b, v8.16b\n"
            "eor v9.16b, v9.16b, v9.16b\n"
            "eor v10.16b, v10.16b, v10.16b\n"
            "eor v11.16b, v11.16b, v11.16b\n"
            "eor v12.16b, v12.16b, v12.16b\n"
            "eor v13.16b, v13.16b, v13.16b\n"
            "eor v14.16b, v14.16b, v14.16b\n"
            "eor v15.16b, v15.16b, v15.16b\n"
            "eor v16.16b, v16.16b, v16.16b\n"
            "eor v17.16b, v17.16b, v17.16b\n"
            "eor v18.16b, v18.16b, v18.16b\n"
            "eor v19.16b, v19.16b, v19.16b\n"
            "eor v20.16b, v20.16b, v20.16b\n"
            "eor v21.16b, v21.16b, v21.16b\n"
            "2: \n"
            // The zero-point correction sums cover only this call's K range,
            // so they must start from zero on BOTH paths (previously they were
            // cleared only when is_first_k was set).
            "eor v27.16b, v27.16b, v27.16b\n"
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "eor v30.16b, v30.16b, v30.16b\n"
            "cbz %w[oddk], 3f\n"
            // Peeled odd step: one 4-deep slice of A (8 rows) and B (8 cols).
            "ldr %q[a0], [%[a_ptr]], #16\n"
            "ldr %q[a1], [%[a_ptr]], #16\n"
            "ldr %q[b0], [%[b_ptr]], #16\n"
            "ldr %q[b1], [%[b_ptr]], #16\n"
            "udot v27.4s, %[a0].16b, v24.16b\n"
            "udot v28.4s, %[a1].16b, v24.16b\n"
            "udot v29.4s, %[b0].16b, v25.16b\n"
            "udot v30.4s, %[b1].16b, v25.16b\n"
            "udot v6.4s, %[b0].16b, %[a0].4b[0]\n"
            "udot v7.4s, %[b0].16b, %[a0].4b[1]\n"
            "udot v8.4s, %[b0].16b, %[a0].4b[2]\n"
            "udot v9.4s, %[b0].16b, %[a0].4b[3]\n"
            "udot v10.4s, %[b0].16b, %[a1].4b[0]\n"
            "udot v11.4s, %[b0].16b, %[a1].4b[1]\n"
            "udot v12.4s, %[b0].16b, %[a1].4b[2]\n"
            "udot v13.4s, %[b0].16b, %[a1].4b[3]\n"
            "udot v14.4s, %[b1].16b, %[a0].4b[0]\n"
            "udot v15.4s, %[b1].16b, %[a0].4b[1]\n"
            "udot v16.4s, %[b1].16b, %[a0].4b[2]\n"
            "udot v17.4s, %[b1].16b, %[a0].4b[3]\n"
            "udot v18.4s, %[b1].16b, %[a1].4b[0]\n"
            "udot v19.4s, %[b1].16b, %[a1].4b[1]\n"
            "udot v20.4s, %[b1].16b, %[a1].4b[2]\n"
            "udot v21.4s, %[b1].16b, %[a1].4b[3]\n"
            "cbz %w[k], 4f\n"
            "3:\n"
            // Main loop: two 4-deep steps per iteration.
            "ldr %q[a0], [%[a_ptr]], #16\n"
            "ldr %q[a1], [%[a_ptr]], #16\n"
            "ldr %q[a0a], [%[a_ptr]], #16\n"
            "ldr %q[a1a], [%[a_ptr]], #16\n"
            "ldr %q[b0], [%[b_ptr]], #16\n"
            "ldr %q[b1], [%[b_ptr]], #16\n"
            "udot v27.4s, %[a0].16b, v24.16b\n"
            "udot v28.4s, %[a1].16b, v24.16b\n"
            "udot v27.4s, %[a0a].16b, v24.16b\n"
            "udot v28.4s, %[a1a].16b, v24.16b\n"
            "udot v29.4s, %[b0].16b, v25.16b\n"
            "udot v30.4s, %[b1].16b, v25.16b\n"
            "udot v6.4s, %[b0].16b, %[a0].4b[0]\n"
            "udot v7.4s, %[b0].16b, %[a0].4b[1]\n"
            "udot v8.4s, %[b0].16b, %[a0].4b[2]\n"
            "udot v9.4s, %[b0].16b, %[a0].4b[3]\n"
            "udot v10.4s, %[b0].16b, %[a1].4b[0]\n"
            "udot v11.4s, %[b0].16b, %[a1].4b[1]\n"
            "udot v12.4s, %[b0].16b, %[a1].4b[2]\n"
            "udot v13.4s, %[b0].16b, %[a1].4b[3]\n"
            "udot v14.4s, %[b1].16b, %[a0].4b[0]\n"
            "udot v15.4s, %[b1].16b, %[a0].4b[1]\n"
            "udot v16.4s, %[b1].16b, %[a0].4b[2]\n"
            "udot v17.4s, %[b1].16b, %[a0].4b[3]\n"
            "udot v18.4s, %[b1].16b, %[a1].4b[0]\n"
            "udot v19.4s, %[b1].16b, %[a1].4b[1]\n"
            "udot v20.4s, %[b1].16b, %[a1].4b[2]\n"
            "udot v21.4s, %[b1].16b, %[a1].4b[3]\n"
            // Second step of the pair: B is reloaded into the same registers.
            "ldr %q[b0], [%[b_ptr]], #16\n"
            "ldr %q[b1], [%[b_ptr]], #16\n"
            "udot v29.4s, %[b0].16b, v25.16b\n"
            "udot v30.4s, %[b1].16b, v25.16b\n"
            "udot v6.4s, %[b0].16b, %[a0a].4b[0]\n"
            "udot v7.4s, %[b0].16b, %[a0a].4b[1]\n"
            "udot v8.4s, %[b0].16b, %[a0a].4b[2]\n"
            "udot v9.4s, %[b0].16b, %[a0a].4b[3]\n"
            "udot v10.4s, %[b0].16b, %[a1a].4b[0]\n"
            "udot v11.4s, %[b0].16b, %[a1a].4b[1]\n"
            "udot v12.4s, %[b0].16b, %[a1a].4b[2]\n"
            "udot v13.4s, %[b0].16b, %[a1a].4b[3]\n"
            "udot v14.4s, %[b1].16b, %[a0a].4b[0]\n"
            "udot v15.4s, %[b1].16b, %[a0a].4b[1]\n"
            "udot v16.4s, %[b1].16b, %[a0a].4b[2]\n"
            "udot v17.4s, %[b1].16b, %[a0a].4b[3]\n"
            "udot v18.4s, %[b1].16b, %[a1a].4b[0]\n"
            "udot v19.4s, %[b1].16b, %[a1a].4b[1]\n"
            "udot v20.4s, %[b1].16b, %[a1a].4b[2]\n"
            "udot v21.4s, %[b1].16b, %[a1a].4b[3]\n"
            "subs %w[k], %w[k], #1\n"
            "bne 3b\n"
            "4:\n"
            // Fold zAB into the row sums, then subtract the row term (per
            // lane) and the column term (per vector) from every accumulator.
            "sub v27.4s, v27.4s, v26.4s\n"
            "sub v28.4s, v28.4s, v26.4s\n"
            SUB_LANE(6, 27, 0, 29, 23)
            SUB_LANE(14, 27, 0, 30, 23)
            SUB_LANE(7, 27, 1, 29, 23)
            SUB_LANE(15, 27, 1, 30, 23)
            SUB_LANE(8, 27, 2, 29, 23)
            SUB_LANE(16, 27, 2, 30, 23)
            SUB_LANE(9, 27, 3, 29, 23)
            SUB_LANE(17, 27, 3, 30, 23)
            SUB_LANE(10, 28, 0, 29, 23)
            SUB_LANE(18, 28, 0, 30, 23)
            SUB_LANE(11, 28, 1, 29, 23)
            SUB_LANE(19, 28, 1, 30, 23)
            SUB_LANE(12, 28, 2, 29, 23)
            SUB_LANE(20, 28, 2, 30, 23)
            SUB_LANE(13, 28, 3, 29, 23)
            SUB_LANE(21, 28, 3, 30, 23)
            "stp q6, q14, [%[outptr0]]\n"
            "stp q7, q15, [%[outptr1]]\n"
            "stp q8, q16, [%[outptr2]]\n"
            "stp q9, q17, [%[outptr3]]\n"
            "stp q10, q18, [%[outptr4]]\n"
            "stp q11, q19, [%[outptr5]]\n"
            "stp q12, q20, [%[outptr6]]\n"
            "stp q13, q21, [%[outptr7]]\n"
            // The vector scratch operands are write-only ("=w"): each one is
            // loaded before its first use, matching the other kernels here.
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "=w"(a0), [a1] "=w"(a1),
              [a0a] "=w"(a0a), [a1a] "=w"(a1a), [b0] "=w"(b0), [b1] "=w"(b1),
              [k] "+r"(k), [LDC] "+r"(LDC), [oddk] "+r"(oddk),
              [is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0),
              [zero_point_A] "+r"(zero_point_A), [zero_point_B] "+r"(zero_point_B),
              [zAB] "+r"(zAB), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2),
              [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
              [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7)
            :
            : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17", "v18", "v19", "v20", "v21", "v23", "v24", "v25", "v26", "v27",
              "v28", "v29", "v30", "cc", "memory");
}
/**
 * \brief Compute up to 4 rows x 8 columns of int32 output from packed uint8
 *        operands using UDOT, including the zero-point correction terms.
 *
 * Product accumulators are v6-v9 (columns 0-3) and v10-v13 (columns 4-7); v28
 * holds the A row sums scaled by zero_point_B, v29/v30 the B column sums scaled
 * by zero_point_A. Here zAB is folded into the column sums before the final
 * subtraction, so every output element still receives it exactly once.
 *
 * \param packA  packed A panel; every 16-byte load holds 4 rows x 4 k-values
 * \param packB  packed B panel; every 16-byte load holds 4 columns x 4 k-values
 * \param K      depth; divided by 4 here. Presumably a positive multiple of 4
 *               -- confirm with caller.
 * \param output top-left element of the destination tile
 * \param LDC    destination row stride in elements (converted to bytes below)
 * \param is_first_k  true: start from zero; false: accumulate onto \p output
 * \param m_remain    number of valid rows; only that many rows (at most 4) are
 *                    loaded and stored
 * \param zero_point_A, zero_point_B  quantization zero points
 * \param zAB    constant added to every output element; presumably
 *               zero_point_A * zero_point_B * K -- confirm with caller.
 */
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_4x8(
        const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC,
        bool is_first_k, int m_remain, uint8_t zero_point_A, uint8_t zero_point_B,
        uint32_t zAB) {
    K /= 4;
    const uint8_t* a_ptr = packA;
    const uint8_t* b_ptr = packB;
    // One odd 4-deep step is peeled off first; the main loop then consumes two
    // steps per iteration.
    int oddk = (K & 1);
    int k = K / 2;
    int32x4_t a0;
    int32x4_t b0;
    int32x4_t b1;
    int32x4_t a0a;
    int32x4_t b0a;
    int32x4_t b1a;
    // Row stride in bytes from here on.
    // NOTE(review): LDC and m_remain are 32-bit ints but are read through
    // %x[...] below; this relies on the upper half of those registers being
    // clean -- confirm.
    LDC = LDC * sizeof(int32_t);
    int32_t* outptr0 = output;
    int32_t* outptr1;
    int32_t* outptr2;
    int32_t* outptr3;
    // Row countdown used by LOAD_C / STORE_C.
    size_t x0;

// Load one output row (two q registers) unless the row budget is exhausted.
#define LOAD_LINE(v1, v2, m) \
    "cbz %[x0], 100f\n" \
    "ldp " v1 "," v2 ", [%[outptr" m "]]\n" \
    "subs %[x0], %[x0], #1\n"

// Load the first m_remain rows of the tile.
#define LOAD_C \
    "mov %[x0], %x[m_remain]\n" \
    LOAD_LINE("q6", "q10", "0") \
    LOAD_LINE("q7", "q11", "1") \
    LOAD_LINE("q8", "q12", "2") \
    LOAD_LINE("q9", "q13", "3") \
    "100:\n"

// Store one output row (two q registers) unless the row budget is exhausted.
#define STORE_LINE(v1, v2, m) \
    "cbz %[x0], 101f\n" \
    "stp " v1 "," v2 ", [%[outptr" m "]]\n" \
    "subs %[x0], %[x0], #1\n"

// Store the first m_remain rows of the tile.
#define STORE_C \
    "mov %[x0], %x[m_remain]\n" \
    STORE_LINE("q6", "q10", "0") \
    STORE_LINE("q7", "q11", "1") \
    STORE_LINE("q8", "q12", "2") \
    STORE_LINE("q9", "q13", "3") \
    "101:\n"

    asm volatile(
            // Row pointers of the destination tile.
            "add %[outptr1], %[outptr0], %x[LDC]\n"
            "add %[outptr2], %[outptr1], %x[LDC]\n"
            "add %[outptr3], %[outptr2], %x[LDC]\n"
            // Broadcast the quantization constants.
            "dup v24.16b, %w[zero_point_B] \n"
            "dup v25.16b, %w[zero_point_A] \n"
            "dup v26.4s, %w[zAB] \n"
            "cmp %w[is_first_k], #1\n"
            "beq 1f\n" LOAD_C
            "b 2f\n"
            "1:\n"
            // First K block: start the product accumulators from zero.
            "eor v6.16b, v6.16b, v6.16b\n"
            "eor v7.16b, v7.16b, v7.16b\n"
            "eor v8.16b, v8.16b, v8.16b\n"
            "eor v9.16b, v9.16b, v9.16b\n"
            "eor v10.16b, v10.16b, v10.16b\n"
            "eor v11.16b, v11.16b, v11.16b\n"
            "eor v12.16b, v12.16b, v12.16b\n"
            "eor v13.16b, v13.16b, v13.16b\n"
            "2: \n"
            // The zero-point correction sums cover only this call's K range,
            // so they must start from zero on BOTH paths (previously they were
            // cleared only when is_first_k was set).
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "eor v30.16b, v30.16b, v30.16b\n"
            "cbz %w[oddk], 3f\n"
            // Peeled odd step.
            "ldr %q[a0], [%[a_ptr]], #16\n"
            "ldr %q[b0], [%[b_ptr]], #16\n"
            "ldr %q[b1], [%[b_ptr]], #16\n"
            "udot v28.4s, %[a0].16b, v24.16b\n"
            "udot v29.4s, %[b0].16b, v25.16b\n"
            "udot v30.4s, %[b1].16b, v25.16b\n"
            "udot v6.4s, %[b0].16b, %[a0].4b[0]\n"
            "udot v7.4s, %[b0].16b, %[a0].4b[1]\n"
            "udot v8.4s, %[b0].16b, %[a0].4b[2]\n"
            "udot v9.4s, %[b0].16b, %[a0].4b[3]\n"
            "udot v10.4s, %[b1].16b, %[a0].4b[0]\n"
            "udot v11.4s, %[b1].16b, %[a0].4b[1]\n"
            "udot v12.4s, %[b1].16b, %[a0].4b[2]\n"
            "udot v13.4s, %[b1].16b, %[a0].4b[3]\n"
            "cbz %w[k], 4f\n"
            "3:\n"
            // Main loop: two 4-deep steps per iteration.
            "ldr %q[a0], [%[a_ptr]], #16\n"
            "ldr %q[b0], [%[b_ptr]], #16\n"
            "ldr %q[b1], [%[b_ptr]], #16\n"
            "ldr %q[a0a], [%[a_ptr]], #16\n"
            "ldr %q[b0a], [%[b_ptr]], #16\n"
            "ldr %q[b1a], [%[b_ptr]], #16\n"
            "udot v28.4s, %[a0].16b, v24.16b\n"
            "udot v28.4s, %[a0a].16b, v24.16b\n"
            "udot v29.4s, %[b0].16b, v25.16b\n"
            "udot v30.4s, %[b1].16b, v25.16b\n"
            "udot v29.4s, %[b0a].16b, v25.16b\n"
            "udot v30.4s, %[b1a].16b, v25.16b\n"
            "udot v6.4s, %[b0].16b, %[a0].4b[0]\n"
            "udot v7.4s, %[b0].16b, %[a0].4b[1]\n"
            "udot v8.4s, %[b0].16b, %[a0].4b[2]\n"
            "udot v9.4s, %[b0].16b, %[a0].4b[3]\n"
            "udot v10.4s, %[b1].16b, %[a0].4b[0]\n"
            "udot v11.4s, %[b1].16b, %[a0].4b[1]\n"
            "udot v12.4s, %[b1].16b, %[a0].4b[2]\n"
            "udot v13.4s, %[b1].16b, %[a0].4b[3]\n"
            "udot v6.4s , %[b0a].16b, %[a0a].4b[0]\n"
            "udot v7.4s , %[b0a].16b, %[a0a].4b[1]\n"
            "udot v8.4s, %[b0a].16b, %[a0a].4b[2]\n"
            "udot v9.4s, %[b0a].16b, %[a0a].4b[3]\n"
            "udot v10.4s, %[b1a].16b, %[a0a].4b[0]\n"
            "udot v11.4s, %[b1a].16b, %[a0a].4b[1]\n"
            "udot v12.4s, %[b1a].16b, %[a0a].4b[2]\n"
            "udot v13.4s, %[b1a].16b, %[a0a].4b[3]\n"
            "subs %w[k], %w[k], #1\n"
            "bne 3b\n"
            "4:\n"
            // Fold zAB into the column sums, then subtract the row term (per
            // lane) and the column term (per vector) from every accumulator.
            "sub v29.4s, v29.4s, v26.4s\n"
            "sub v30.4s, v30.4s, v26.4s\n"
            SUB_LANE(6, 28, 0, 29, 23)
            SUB_LANE(10, 28, 0, 30, 23)
            SUB_LANE(7, 28, 1, 29, 23)
            SUB_LANE(11, 28, 1, 30, 23)
            SUB_LANE(8, 28, 2, 29, 23)
            SUB_LANE(12, 28, 2, 30, 23)
            SUB_LANE(9, 28, 3, 29, 23)
            SUB_LANE(13, 28, 3, 30, 23)
            STORE_C
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k),
              [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
              [m_remain] "+r"(m_remain), [zero_point_A] "+r"(zero_point_A),
              [zero_point_B] "+r"(zero_point_B), [zAB] "+r"(zAB), [LDC] "+r"(LDC),
              [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1),
              [b0a] "=w"(b0a), [b1a] "=w"(b1a), [outptr1] "=r"(outptr1),
              [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "=r"(x0)
            :
            : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v23", "v24", "v25",
              "v26", "v28", "v29", "v30", "memory", "cc");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
/**
 * \brief Compute 8 rows x up to 4 columns of int32 output from packed uint8
 *        operands using UDOT, including the zero-point correction terms.
 *
 * Product accumulators are v6-v13 (one register per output row); v27/v28 hold
 * the A row sums scaled by zero_point_B and v29 the B column sums scaled by
 * zero_point_A.
 *
 * \param packA  packed A panel; every 16-byte load holds 4 rows x 4 k-values
 * \param packB  packed B panel; every 16-byte load holds 4 columns x 4 k-values
 * \param K      depth; divided by 4 here. Presumably a positive multiple of 4
 *               -- confirm with caller.
 * \param output top-left element of the destination tile
 * \param LDC    destination row stride in elements (converted to bytes below)
 * \param is_first_k  true: start from zero; false: accumulate onto \p output
 * \param n_remain    number of valid columns; >= 4 moves a full vector per
 *                    row, 1..3 moves that many leading lanes, 0 moves nothing
 * \param zero_point_A, zero_point_B  quantization zero points
 * \param zAB    constant added to every output element; presumably
 *               zero_point_A * zero_point_B * K -- confirm with caller.
 */
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_8x4(
        const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC,
        bool is_first_k, int n_remain, uint8_t zero_point_A, uint8_t zero_point_B,
        uint32_t zAB) {
    K /= 4;
    const uint8_t* a_ptr = packA;
    const uint8_t* b_ptr = packB;
    // One odd 4-deep step is peeled off first; the main loop then consumes two
    // steps per iteration.
    int oddk = (K & 1);
    int k = K / 2;
    int32x4_t a0;
    int32x4_t a1;
    int32x4_t b0;
    int32x4_t b0a;
    int32x4_t a0a;
    int32x4_t a1a;
    // Row stride in bytes from here on.
    // NOTE(review): LDC is a 32-bit int but is read through %x[LDC] below; this
    // relies on the upper half of the register being clean -- confirm.
    LDC = LDC * sizeof(int32_t);
    int32_t* outptr0 = output;
    int32_t* outptr1;
    int32_t* outptr2;
    int32_t* outptr3;
    int32_t* outptr4;
    int32_t* outptr5;
    int32_t* outptr6;
    int32_t* outptr7;
    // Walking pointer used by the partial-row load/store sequences.
    size_t x0;

// Load row `n` into v<reg_index>: whole vector when n_remain >= 4, otherwise
// the first n_remain lanes one at a time.
#define LOAD_LINE(reg_index, n) \
    "mov %[x0], %[outptr" n "]\n" \
    "cmp %w[n_remain], #4\n" \
    "blt 100" n "f\n" \
    "ldr q" reg_index ", [%[x0]] \n" \
    "b 101" n "f\n" \
    "100" n ":\n" \
    "cmp %w[n_remain], #0\n" \
    "beq 101" n "f\n" \
    "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
    "cmp %w[n_remain], #1\n" \
    "beq 101" n "f\n" \
    "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
    "cmp %w[n_remain], #2\n" \
    "beq 101" n "f\n" \
    "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
    "101" n ":\n"

#define LOAD_C \
    LOAD_LINE("6", "0") \
    LOAD_LINE("7", "1") \
    LOAD_LINE("8", "2") \
    LOAD_LINE("9", "3") \
    LOAD_LINE("10", "4") \
    LOAD_LINE("11", "5") \
    LOAD_LINE("12", "6") \
    LOAD_LINE("13", "7")

// Store row `n` from v<reg_index> with the same column handling as LOAD_LINE.
#define STORE_LINE(reg_index, n) \
    "mov %[x0], %[outptr" n "]\n" \
    "cmp %w[n_remain], #4\n" \
    "blt 102" n "f\n" \
    "str q" reg_index ", [%[x0]]\n" \
    "b 103" n "f\n" \
    "102" n ":\n" \
    "cmp %w[n_remain], #0\n" \
    "beq 103" n "f\n" \
    "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
    "cmp %w[n_remain], #1\n" \
    "beq 103" n "f\n" \
    "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
    "cmp %w[n_remain], #2\n" \
    "beq 103" n "f\n" \
    "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
    "103" n ":\n"

#define STORE_C \
    STORE_LINE("6", "0") \
    STORE_LINE("7", "1") \
    STORE_LINE("8", "2") \
    STORE_LINE("9", "3") \
    STORE_LINE("10", "4") \
    STORE_LINE("11", "5") \
    STORE_LINE("12", "6") \
    STORE_LINE("13", "7")

    asm volatile(
            // Row pointers of the destination tile.
            "add %[outptr1], %[outptr0], %x[LDC]\n"
            "add %[outptr2], %[outptr1], %x[LDC]\n"
            "add %[outptr3], %[outptr2], %x[LDC]\n"
            "add %[outptr4], %[outptr3], %x[LDC]\n"
            "add %[outptr5], %[outptr4], %x[LDC]\n"
            "add %[outptr6], %[outptr5], %x[LDC]\n"
            "add %[outptr7], %[outptr6], %x[LDC]\n"
            // Broadcast the quantization constants.
            "dup v24.16b, %w[zero_point_B] \n"
            "dup v25.16b, %w[zero_point_A] \n"
            "dup v26.4s, %w[zAB] \n"
            "cmp %w[is_first_k], #1\n"
            "beq 1f\n" LOAD_C
            "b 2f\n"
            "1:\n"
            // First K block: start the product accumulators from zero.
            "eor v6.16b, v6.16b, v6.16b\n"
            "eor v7.16b, v7.16b, v7.16b\n"
            "eor v8.16b, v8.16b, v8.16b\n"
            "eor v9.16b, v9.16b, v9.16b\n"
            "eor v10.16b, v10.16b, v10.16b\n"
            "eor v11.16b, v11.16b, v11.16b\n"
            "eor v12.16b, v12.16b, v12.16b\n"
            "eor v13.16b, v13.16b, v13.16b\n"
            "2: \n"
            // The zero-point correction sums cover only this call's K range,
            // so they must start from zero on BOTH paths (previously they were
            // cleared only when is_first_k was set).
            "eor v27.16b, v27.16b, v27.16b\n"
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "cbz %w[oddk], 3f\n"
            // Peeled odd step.
            "ldr %q[a0], [%[a_ptr]], #16\n"
            "ldr %q[b0], [%[b_ptr]], #16\n"
            "ldr %q[a1], [%[a_ptr]], #16\n"
            "udot v27.4s, %[a0].16b, v24.16b\n"
            "udot v28.4s, %[a1].16b, v24.16b\n"
            "udot v29.4s, %[b0].16b, v25.16b\n"
            "udot v6.4s , %[b0].16b, %[a0].4b[0]\n"
            "udot v7.4s , %[b0].16b, %[a0].4b[1]\n"
            "udot v8.4s, %[b0].16b, %[a0].4b[2]\n"
            "udot v9.4s, %[b0].16b, %[a0].4b[3]\n"
            "udot v10.4s, %[b0].16b, %[a1].4b[0]\n"
            "udot v11.4s, %[b0].16b, %[a1].4b[1]\n"
            "udot v12.4s, %[b0].16b, %[a1].4b[2]\n"
            "udot v13.4s, %[b0].16b, %[a1].4b[3]\n"
            "cbz %w[k], 4f\n"
            "3:\n"
            // Main loop: two 4-deep steps per iteration.
            "ldr %q[a0], [%[a_ptr]], #16\n"
            "ldr %q[b0], [%[b_ptr]], #16\n"
            "ldr %q[a1], [%[a_ptr]], #16\n"
            "ldr %q[a0a], [%[a_ptr]], #16\n"
            "ldr %q[a1a], [%[a_ptr]], #16\n"
            "ldr %q[b0a], [%[b_ptr]], #16\n"
            "udot v27.4s, %[a0].16b, v24.16b\n"
            "udot v28.4s, %[a1].16b, v24.16b\n"
            "udot v27.4s, %[a0a].16b, v24.16b\n"
            "udot v28.4s, %[a1a].16b, v24.16b\n"
            "udot v29.4s, %[b0].16b, v25.16b\n"
            "udot v29.4s, %[b0a].16b, v25.16b\n"
            "udot v6.4s , %[b0].16b, %[a0].4b[0]\n"
            "udot v7.4s , %[b0].16b, %[a0].4b[1]\n"
            "udot v8.4s, %[b0].16b, %[a0].4b[2]\n"
            "udot v9.4s, %[b0].16b, %[a0].4b[3]\n"
            "udot v10.4s, %[b0].16b, %[a1].4b[0]\n"
            "udot v11.4s, %[b0].16b, %[a1].4b[1]\n"
            "udot v12.4s, %[b0].16b, %[a1].4b[2]\n"
            "udot v13.4s, %[b0].16b, %[a1].4b[3]\n"
            "udot v6.4s , %[b0a].16b, %[a0a].4b[0]\n"
            "udot v7.4s , %[b0a].16b, %[a0a].4b[1]\n"
            "udot v8.4s, %[b0a].16b, %[a0a].4b[2]\n"
            "udot v9.4s, %[b0a].16b, %[a0a].4b[3]\n"
            "udot v10.4s, %[b0a].16b, %[a1a].4b[0]\n"
            "udot v11.4s, %[b0a].16b, %[a1a].4b[1]\n"
            "udot v12.4s, %[b0a].16b, %[a1a].4b[2]\n"
            "udot v13.4s, %[b0a].16b, %[a1a].4b[3]\n"
            "subs %w[k], %w[k], #1\n"
            "bne 3b\n"
            "4:\n"
            // Fold zAB into the row sums, then subtract the row term (per
            // lane) and the column term (per vector) from every accumulator.
            "sub v27.4s, v27.4s, v26.4s\n"
            "sub v28.4s, v28.4s, v26.4s\n"
            SUB_LANE(6, 27, 0, 29, 23)
            SUB_LANE(7, 27, 1, 29, 23)
            SUB_LANE(8, 27, 2, 29, 23)
            SUB_LANE(9, 27, 3, 29, 23)
            SUB_LANE(10, 28, 0, 29, 23)
            SUB_LANE(11, 28, 1, 29, 23)
            SUB_LANE(12, 28, 2, 29, 23)
            SUB_LANE(13, 28, 3, 29, 23)
            STORE_C
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
              [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
              [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0),
              [zero_point_A] "+r"(zero_point_A), [zero_point_B] "+r"(zero_point_B),
              [zAB] "+r"(zAB), [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a),
              [a1a] "=w"(a1a), [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1),
              [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4),
              [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7),
              [x0] "=r"(x0)
            :
            : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v23", "v24", "v25",
              "v26", "v27", "v28", "v29", "memory", "cc");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
/**
 * \brief Compute up to 4 rows x up to 4 columns of int32 output from packed
 *        uint8 operands using UDOT, including the zero-point correction terms.
 *
 * Product accumulators are v4-v7 (one register per output row); v28 holds the
 * A row sums scaled by zero_point_B and v29 the B column sums scaled by
 * zero_point_A.
 *
 * \param packA  packed A panel; every 16-byte load holds 4 rows x 4 k-values
 * \param packB  packed B panel; every 16-byte load holds 4 columns x 4 k-values
 * \param K      depth; divided by 4 here. Presumably a positive multiple of 4
 *               -- confirm with caller.
 * \param output top-left element of the destination tile
 * \param LDC    destination row stride in elements (converted to bytes below)
 * \param is_first_k  true: start from zero; false: accumulate onto \p output
 * \param m_remain    number of valid rows; only that many rows (at most 4) are
 *                    loaded and stored
 * \param n_remain    number of valid columns; >= 4 moves a full vector per
 *                    row, 1..3 moves that many leading lanes, 0 moves nothing
 * \param zero_point_A, zero_point_B  quantization zero points
 * \param zAB    constant added to every output element; presumably
 *               zero_point_A * zero_point_B * K -- confirm with caller.
 */
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_4x4(
        const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC,
        bool is_first_k, int m_remain, int n_remain, uint8_t zero_point_A,
        uint8_t zero_point_B, uint32_t zAB) {
    K /= 4;
    // The pointers are only advanced by the asm (16 bytes per load), so the
    // int32_t element type is not used for C++ pointer arithmetic here.
    const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA);
    const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB);
    // One odd 4-deep step is peeled off first; the main loop then consumes two
    // steps per iteration.
    int oddk = (K & 1);
    int k = K / 2;
    int32x4_t a0;
    int32x4_t a0a;
    int32x4_t b0;
    int32x4_t b0a;
    // Row stride in bytes from here on.
    // NOTE(review): LDC and m_remain are 32-bit ints but are read through
    // %x[...] below; this relies on the upper half of those registers being
    // clean -- confirm.
    LDC = LDC * sizeof(int32_t);
    int32_t* outptr0 = output;
    int32_t* outptr1;
    int32_t* outptr2;
    int32_t* outptr3;
    // x0: walking pointer inside a row; x1: row countdown.
    size_t x0, x1;

// Load row `n` into v<reg_index> unless the row budget is exhausted: whole
// vector when n_remain >= 4, otherwise the first n_remain lanes one at a time.
#define LOAD_LINE(reg_index, n) \
    "cbz %[x1], 102f\n" \
    "mov %[x0], %[outptr" n "]\n" \
    "cmp %w[n_remain], #4\n" \
    "blt 100" n "f\n" \
    "ldr q" reg_index ", [%[x0]]\n" \
    "b 101" n "f\n" \
    "100" n ":\n" \
    "cmp %w[n_remain], #0\n" \
    "beq 101" n "f\n" \
    "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
    "cmp %w[n_remain], #1\n" \
    "beq 101" n "f\n" \
    "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
    "cmp %w[n_remain], #2\n" \
    "beq 101" n "f\n" \
    "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
    "101" n ":\n" \
    "subs %[x1], %[x1], #1\n"

#define LOAD_C \
    "mov %[x1], %x[m_remain]\n" \
    LOAD_LINE("4", "0") \
    LOAD_LINE("5", "1") \
    LOAD_LINE("6", "2") \
    LOAD_LINE("7", "3") \
    "102:\n"

// Store row `n` from v<reg_index> with the same row/column handling as
// LOAD_LINE.
#define STORE_LINE(reg_index, n) \
    "cbz %[x1], 105f\n" \
    "mov %[x0], %[outptr" n "]\n" \
    "cmp %w[n_remain], #4\n" \
    "blt 103" n "f\n" \
    "str q" reg_index ", [%[x0]]\n" \
    "b 104" n "f\n" \
    "103" n ":\n" \
    "cmp %w[n_remain], #0\n" \
    "beq 104" n "f\n" \
    "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
    "cmp %w[n_remain], #1\n" \
    "beq 104" n "f\n" \
    "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
    "cmp %w[n_remain], #2\n" \
    "beq 104" n "f\n" \
    "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
    "104" n ":\n" \
    "subs %[x1], %[x1], #1\n"

#define STORE_C \
    "mov %[x1], %x[m_remain]\n" \
    STORE_LINE("4", "0") \
    STORE_LINE("5", "1") \
    STORE_LINE("6", "2") \
    STORE_LINE("7", "3") \
    "105:\n"

    asm volatile(
            // Row pointers of the destination tile.
            "add %[outptr1], %[outptr0], %x[LDC]\n"
            "add %[outptr2], %[outptr1], %x[LDC]\n"
            "add %[outptr3], %[outptr2], %x[LDC]\n"
            // Broadcast the quantization constants.
            "dup v24.16b, %w[zero_point_B] \n"
            "dup v25.16b, %w[zero_point_A] \n"
            "dup v26.4s, %w[zAB] \n"
            "cmp %w[is_first_k], #1\n"
            "beq 1f\n" LOAD_C
            "b 2f\n"
            "1:\n"
            // First K block: start the product accumulators from zero.
            "eor v4.16b, v4.16b, v4.16b\n"
            "eor v5.16b, v5.16b, v5.16b\n"
            "eor v6.16b, v6.16b, v6.16b\n"
            "eor v7.16b, v7.16b, v7.16b\n"
            "2: \n"
            // The zero-point correction sums cover only this call's K range,
            // so they must start from zero on BOTH paths (previously they were
            // cleared only when is_first_k was set).
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "cbz %w[oddk], 3f\n"
            // Peeled odd step.
            "ldr %q[a0], [%[a_ptr]], #16\n"
            "ldr %q[b0], [%[b_ptr]], #16\n"
            "udot v28.4s, %[a0].16b, v24.16b\n"
            "udot v29.4s, %[b0].16b, v25.16b\n"
            "udot v4.4s , %[b0].16b, %[a0].4b[0]\n"
            "udot v5.4s , %[b0].16b, %[a0].4b[1]\n"
            "udot v6.4s, %[b0].16b, %[a0].4b[2]\n"
            "udot v7.4s, %[b0].16b, %[a0].4b[3]\n"
            "cbz %w[k], 4f\n"
            "3:\n"
            // Main loop: two 4-deep steps per iteration.
            "ldr %q[a0], [%[a_ptr]], #16\n"
            "ldr %q[b0], [%[b_ptr]], #16\n"
            "ldr %q[a0a], [%[a_ptr]], #16\n"
            "ldr %q[b0a], [%[b_ptr]], #16\n"
            "udot v28.4s, %[a0].16b, v24.16b\n"
            "udot v28.4s, %[a0a].16b, v24.16b\n"
            "udot v29.4s, %[b0].16b, v25.16b\n"
            "udot v29.4s, %[b0a].16b, v25.16b\n"
            "udot v4.4s , %[b0].16b, %[a0].4b[0]\n"
            "udot v5.4s , %[b0].16b, %[a0].4b[1]\n"
            "udot v6.4s, %[b0].16b, %[a0].4b[2]\n"
            "udot v7.4s, %[b0].16b, %[a0].4b[3]\n"
            "udot v4.4s , %[b0a].16b, %[a0a].4b[0]\n"
            "udot v5.4s , %[b0a].16b, %[a0a].4b[1]\n"
            "udot v6.4s, %[b0a].16b, %[a0a].4b[2]\n"
            "udot v7.4s, %[b0a].16b, %[a0a].4b[3]\n"
            "subs %w[k], %w[k], #1\n"
            "bne 3b\n"
            "4:\n"
            // Fold zAB into the row sums, then subtract the row term (per
            // lane) and the column term (per vector) from every accumulator.
            "sub v28.4s, v28.4s, v26.4s\n"
            SUB_LANE(4, 28, 0, 29, 23)
            SUB_LANE(5, 28, 1, 29, 23)
            SUB_LANE(6, 28, 2, 29, 23)
            SUB_LANE(7, 28, 3, 29, 23)
            STORE_C
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk),
              [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain),
              [m_remain] "+r"(m_remain), [LDC] "+r"(LDC),
              [zero_point_A] "+r"(zero_point_A), [zero_point_B] "+r"(zero_point_B),
              [zAB] "+r"(zAB), [outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0),
              [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1),
              [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "=r"(x0),
              [x1] "=r"(x1)
            :
            : "v4", "v5", "v6", "v7", "v23", "v24", "v25", "v26", "v28", "v29",
              "memory", "cc");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
#undef SUB_LANE
/**
 * \brief Pack the region rows [k0, kmax) x columns [x0, xmax) of \p in into
 *        \p out, four source rows at a time.
 *
 * Columns are consumed in groups of 8 first (via transpose_8x4_1_b), then in
 * groups of 4 and finally one partial group (both via transpose_4). The
 * 8-column panels are laid out first in \p out; the 4-column panels start
 * right after them. Source rows at or beyond kmax are replaced by an all-zero
 * row.
 *
 * \param out   destination buffer
 * \param in    source matrix
 * \param ldin  source row stride in elements
 */
static void gemm_u8_8x8_transpose_pack_helper(
        uint8_t* out, const uint8_t* in, int ldin, int x0, int xmax, int k0, int kmax) {
    // Stand-in for source rows that fall past kmax.
    uint8_t zerobuff[16];
    std::memset(zerobuff, 0, sizeof(zerobuff));

    const int ksize = kmax - k0;
    // Bytes per 4-column panel and per 8-column panel.
    const int ksize4 = round_up(ksize, 4) * 4;
    const int ksize8 = ksize4 * 2;

    // Where the current k-step lands inside the first 8-wide / 4-wide panel.
    uint8_t* wide_base = out;
    uint8_t* narrow_base = out + ((xmax - x0) / 8) * ksize8;

    for (int k = k0; k < kmax; k += 4) {
        const uint8_t* row0 = in + k * ldin + x0;
        const uint8_t* row1 = row0 + ldin;
        const uint8_t* row2 = row1 + ldin;
        const uint8_t* row3 = row2 + ldin;
        prefetch_2x(row0);
        prefetch_2x(row1);
        prefetch_2x(row2);
        prefetch_2x(row3);

        // Non-negative when this group of four rows reaches past kmax; the
        // value (0..2) tells how many extra rows besides row3 are missing.
        const int overrun = k + 3 - kmax;
        // Re-points the missing rows at the zero buffer. It is re-applied
        // before every transpose call, exactly as the original code did.
        const auto redirect_missing_rows = [&]() {
            if (overrun == 2) {
                row1 = zerobuff;
            }
            if (overrun >= 1) {
                row2 = zerobuff;
            }
            row3 = zerobuff;
        };

        int x = x0;

        // 8-column panels.
        uint8_t* dst = wide_base;
        while (x + 7 < xmax) {
            if (overrun >= 0) {
                if (overrun > 2) {
                    megdnn_assert(0);
                } else {
                    redirect_missing_rows();
                }
            }
            transpose_8x4_1_b(row0, row1, row2, row3, dst);
            dst += ksize8;
            x += 8;
        }

        // 4-column panels.
        dst = narrow_base;
        while (x + 3 < xmax) {
            if (overrun >= 0) {
                if (overrun > 2) {
                    megdnn_assert(0);
                } else {
                    redirect_missing_rows();
                }
            }
            transpose_4(row0, row1, row2, row3, dst, 4, 4);
            dst += ksize4;
            x += 4;
        }

        // Trailing partial panel (fewer than 4 columns left).
        if (x < xmax) {
            if (overrun >= 0) {
                if (overrun > 2) {
                    megdnn_assert(0);
                } else {
                    redirect_missing_rows();
                }
            }
            transpose_4(row0, row1, row2, row3, dst, 4, xmax - x);
        }

        wide_base += 8 * 4;
        narrow_base += 4 * 4;
    }
}
/**
 * \brief Pack the region rows [y0, ymax) x columns [k0, kmax) of \p inptr into
 *        \p outptr.
 *
 * Rows are consumed in groups of 8 while at least 8 remain, then in groups of
 * 4. Inside each group the k-range is handled 16 columns at a time
 * (interleave_8x4_4_b / interleave_4x4_4_b), with the leftover columns passed
 * to interleave_8 / interleave_4. In a trailing 4-row group, rows at or beyond
 * ymax are replaced by an all-zero row.
 *
 * \param outptr  destination buffer; handed to the interleave helpers as-is
 * \param inptr   source matrix
 * \param ldin    source row stride in elements
 */
static void gemm_u8_8x8_interleave_pack_helper(
        uint8_t* outptr, const uint8_t* inptr, int ldin, int y0, int ymax, int k0,
        int kmax) {
    // Stand-in for source rows that fall past ymax.
    uint8_t zerobuff[16];
    std::memset(zerobuff, 0, sizeof(zerobuff));

    int y = y0;

    // Full groups of 8 rows: no padding is needed here.
    while (y + 7 < ymax) {
        const uint8_t* row0 = inptr + y * ldin + k0;
        const uint8_t* row1 = row0 + ldin;
        const uint8_t* row2 = row1 + ldin;
        const uint8_t* row3 = row2 + ldin;
        const uint8_t* row4 = row3 + ldin;
        const uint8_t* row5 = row4 + ldin;
        const uint8_t* row6 = row5 + ldin;
        const uint8_t* row7 = row6 + ldin;
        prefetch_2x(row0);
        prefetch_2x(row1);
        prefetch_2x(row2);
        prefetch_2x(row3);
        prefetch_2x(row4);
        prefetch_2x(row5);
        prefetch_2x(row6);
        prefetch_2x(row7);

        int k_left = kmax - k0;
        while (k_left > 15) {
            interleave_8x4_4_b(row0, row1, row2, row3, row4, row5, row6, row7, outptr);
            k_left -= 16;
        }
        if (k_left > 0) {
            interleave_8(
                    row0, row1, row2, row3, row4, row5, row6, row7, outptr, 4, k_left);
        }
        y += 8;
    }

    // Remaining rows in groups of 4, zero-padding the last group if needed.
    while (y < ymax) {
        const uint8_t* row0 = inptr + y * ldin + k0;
        const uint8_t* row1 = row0 + ldin;
        const uint8_t* row2 = row1 + ldin;
        const uint8_t* row3 = row2 + ldin;
        prefetch_2x(row0);
        prefetch_2x(row1);
        prefetch_2x(row2);
        prefetch_2x(row3);

        // Non-negative when this group of four rows reaches past ymax; the
        // value (0..2) tells how many extra rows besides row3 are missing.
        const int overrun = y + 3 - ymax;
        // Re-points the missing rows at the zero buffer. It is re-applied
        // before every interleave call, exactly as the original code did.
        const auto redirect_missing_rows = [&]() {
            if (overrun == 2) {
                row1 = zerobuff;
            }
            if (overrun >= 1) {
                row2 = zerobuff;
            }
            row3 = zerobuff;
        };

        int k_left = kmax - k0;
        while (k_left > 15) {
            if (overrun >= 0) {
                if (overrun > 2) {
                    megdnn_assert(0);
                } else {
                    redirect_missing_rows();
                }
            }
            interleave_4x4_4_b(row0, row1, row2, row3, outptr);
            k_left -= 16;
        }
        if (k_left > 0) {
            if (overrun >= 0) {
                if (overrun > 2) {
                    megdnn_assert(0);
                } else {
                    redirect_missing_rows();
                }
            }
            interleave_4(row0, row1, row2, row3, outptr, 4, k_left);
        }
        y += 4;
    }
}
} } }
#endif