#include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
namespace {
/**
 * uint8 x uint8 -> int32 GEMV with zero points, built on the ARMv8.2
 * unsigned dot-product intrinsics (vdotq_u32 / vdot_u32).
 *
 * For every m in [0, M) this computes
 *     C[m * Cstride] = sum_k (A[m * Astride + k] - zero_point_A) *
 *                            (B[k] - zero_point_B)
 * using the expansion
 *     sum a*b - zB * sum a - zA * sum b + K * zA * zB,
 * so the inner loops only need unsigned 8-bit dot products.
 *
 * Every partial sum is kept in uint32_t. Unsigned arithmetic wraps modulo 2^32,
 * so an intermediate term such as sum a*b may exceed INT32_MAX without causing
 * signed-overflow UB. That can happen once K is larger than roughly 33000. The
 * final combination is converted to int32_t exactly once.
 *
 * Rows are processed two at a time, then a possible leftover row on its own.
 * Along K, each row uses a 16-wide loop, at most one 8-wide step, and a scalar
 * tail.
 *
 * @param A  row m starts at A + m * Astride; K bytes are read per row
 * @param B  K contiguous bytes (the kernel requires N == 1 and Bstride == 1)
 * @param C  element m is written to C[m * Cstride]
 */
MEGDNN_ATTRIBUTE_TARGET("dotprod")
void gemv_naive_n(
        const uint8_t* __restrict A, const uint8_t* __restrict B, int32_t* __restrict C,
        size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride,
        uint8_t zero_point_A, uint8_t zero_point_B) {
    megdnn_assert(N == 1 && Bstride == 1);
    // K * zA * zB, reduced modulo 2^32 like every other partial sum.
    const uint32_t zAB = static_cast<uint32_t>(zero_point_A) *
                         static_cast<uint32_t>(zero_point_B) *
                         static_cast<uint32_t>(K);
    const uint8x16_t zAq = vdupq_n_u8(zero_point_A);
    const uint8x16_t zBq = vdupq_n_u8(zero_point_B);
    const uint8x8_t zA = vdup_n_u8(zero_point_A);
    const uint8x8_t zB = vdup_n_u8(zero_point_B);

    size_t m = 0;
    // Two rows per iteration: the B loads and the zA * sum(B) term are shared.
    for (; m + 2 <= M; m += 2) {
        const uint8_t* a_row0 = A + m * Astride;
        const uint8_t* a_row1 = A + (m + 1) * Astride;
        size_t k = 0;

        // 16-wide loop. Each UDOT lane accumulates 4 byte products. Interleaving
        // the two rows 8 bytes at a time puts row m in lanes 0-1 and row m+1 in
        // lanes 2-3 of dot_q.
        uint32x4_t dot_q = vdupq_n_u32(0);
        uint32x4_t sum_a0_q = vdupq_n_u32(0);  // zB * sum(row m)
        uint32x4_t sum_a1_q = vdupq_n_u32(0);  // zB * sum(row m + 1)
        uint32x4_t sum_b_q = vdupq_n_u32(0);   // zA * sum(B)
        for (; k + 16 <= K; k += 16) {
            const uint8x16_t a0 = vld1q_u8(a_row0 + k);
            const uint8x16_t a1 = vld1q_u8(a_row1 + k);
            const uint8x16_t b = vld1q_u8(B + k);
            sum_a0_q = vdotq_u32(sum_a0_q, zBq, a0);
            sum_a1_q = vdotq_u32(sum_a1_q, zBq, a1);
            sum_b_q = vdotq_u32(sum_b_q, zAq, b);

            const uint64x2_t a0_64 = vreinterpretq_u64_u8(a0);
            const uint64x2_t a1_64 = vreinterpretq_u64_u8(a1);
            const uint64x2_t b_64 = vreinterpretq_u64_u8(b);
            // a_lo = {a0[0..8), a1[0..8)}, a_hi = {a0[8..16), a1[8..16)};
            // b_lo / b_hi repeat the matching half of b for both rows.
            const uint8x16_t a_lo = vreinterpretq_u8_u64(vzip1q_u64(a0_64, a1_64));
            const uint8x16_t a_hi = vreinterpretq_u8_u64(vzip2q_u64(a0_64, a1_64));
            const uint8x16_t b_lo = vreinterpretq_u8_u64(vzip1q_u64(b_64, b_64));
            const uint8x16_t b_hi = vreinterpretq_u8_u64(vzip2q_u64(b_64, b_64));
            dot_q = vdotq_u32(dot_q, a_lo, b_lo);
            dot_q = vdotq_u32(dot_q, a_hi, b_hi);
        }

        // 8-wide step. Fewer than 16 columns remain, so this runs at most once.
        // Accumulate in vectors and reduce once afterwards.
        uint32x2_t dot0_d = vdup_n_u32(0);
        uint32x2_t dot1_d = vdup_n_u32(0);
        uint32x2_t sum_a0_d = vdup_n_u32(0);
        uint32x2_t sum_a1_d = vdup_n_u32(0);
        uint32x2_t sum_b_d = vdup_n_u32(0);
        for (; k + 8 <= K; k += 8) {
            const uint8x8_t a0 = vld1_u8(a_row0 + k);
            const uint8x8_t a1 = vld1_u8(a_row1 + k);
            const uint8x8_t b = vld1_u8(B + k);
            dot0_d = vdot_u32(dot0_d, a0, b);
            dot1_d = vdot_u32(dot1_d, a1, b);
            sum_a0_d = vdot_u32(sum_a0_d, a0, zB);
            sum_a1_d = vdot_u32(sum_a1_d, a1, zB);
            sum_b_d = vdot_u32(sum_b_d, b, zA);
        }

        uint32_t dot0 = vgetq_lane_u32(dot_q, 0) + vgetq_lane_u32(dot_q, 1) +
                        vaddv_u32(dot0_d);
        uint32_t dot1 = vgetq_lane_u32(dot_q, 2) + vgetq_lane_u32(dot_q, 3) +
                        vaddv_u32(dot1_d);
        uint32_t sum_a0 = vaddvq_u32(sum_a0_q) + vaddv_u32(sum_a0_d);
        uint32_t sum_a1 = vaddvq_u32(sum_a1_q) + vaddv_u32(sum_a1_d);
        uint32_t sum_b = vaddvq_u32(sum_b_q) + vaddv_u32(sum_b_d);

        // Scalar tail (fewer than 8 columns).
        for (; k < K; ++k) {
            const uint32_t a0 = a_row0[k];
            const uint32_t a1 = a_row1[k];
            const uint32_t b = B[k];
            dot0 += a0 * b;
            dot1 += a1 * b;
            sum_a0 += a0 * zero_point_B;
            sum_a1 += a1 * zero_point_B;
            sum_b += b * zero_point_A;
        }
        C[m * Cstride] = static_cast<int32_t>(dot0 + zAB - sum_b - sum_a0);
        C[(m + 1) * Cstride] = static_cast<int32_t>(dot1 + zAB - sum_b - sum_a1);
    }

    // Leftover single row (M odd).
    for (; m < M; ++m) {
        const uint8_t* a_row = A + m * Astride;
        size_t k = 0;

        uint32x4_t dot_q = vdupq_n_u32(0);
        uint32x4_t sum_a_q = vdupq_n_u32(0);  // zB * sum(row m)
        uint32x4_t sum_b_q = vdupq_n_u32(0);  // zA * sum(B)
        for (; k + 16 <= K; k += 16) {
            const uint8x16_t a = vld1q_u8(a_row + k);
            const uint8x16_t b = vld1q_u8(B + k);
            dot_q = vdotq_u32(dot_q, a, b);
            sum_a_q = vdotq_u32(sum_a_q, zBq, a);
            sum_b_q = vdotq_u32(sum_b_q, zAq, b);
        }

        uint32x2_t dot_d = vdup_n_u32(0);
        uint32x2_t sum_a_d = vdup_n_u32(0);
        uint32x2_t sum_b_d = vdup_n_u32(0);
        for (; k + 8 <= K; k += 8) {
            const uint8x8_t a = vld1_u8(a_row + k);
            const uint8x8_t b = vld1_u8(B + k);
            dot_d = vdot_u32(dot_d, a, b);
            sum_a_d = vdot_u32(sum_a_d, a, zB);
            sum_b_d = vdot_u32(sum_b_d, b, zA);
        }

        uint32_t dot = vaddvq_u32(dot_q) + vaddv_u32(dot_d);
        uint32_t sum_a = vaddvq_u32(sum_a_q) + vaddv_u32(sum_a_d);
        uint32_t sum_b = vaddvq_u32(sum_b_q) + vaddv_u32(sum_b_d);

        for (; k < K; ++k) {
            const uint32_t a = a_row[k];
            const uint32_t b = B[k];
            dot += a * b;
            sum_a += a * zero_point_B;
            sum_b += b * zero_point_A;
        }
        C[m * Cstride] = static_cast<int32_t>(dot + zAB - sum_b - sum_a);
    }
}
}  // namespace
/**
 * Decide whether the quint8 dot-product GEMV path should handle this matmul.
 *
 * Returns true only when neither operand is transposed, B has a single column
 * (N == 1), and LDB == 1. M and K do not affect the decision.
 */
bool megdnn::aarch64::matmul::is_gemv_like_preferred_quint8(
        bool transposeA, bool transposeB, size_t M, size_t N, size_t K,
        size_t , size_t LDB, size_t ) {
    MEGDNN_MARK_USED_VAR(M);
    MEGDNN_MARK_USED_VAR(K);
    const bool no_transpose = !(transposeA || transposeB);
    const bool single_contiguous_column = (N == 1) && (LDB == 1);
    return no_transpose && single_contiguous_column;
}
/**
 * Entry point for the quint8 GEMV path.
 *
 * Checks that B is a single column (N == 1) and forwards every argument
 * unchanged to the file-local dot-product kernel. That kernel also asserts
 * Bstride == 1.
 */
void megdnn::aarch64::matmul::gemv_like_quint8(
        const uint8_t* __restrict A, const uint8_t* __restrict B, int32_t* __restrict C,
        size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride,
        uint8_t zero_point_A, uint8_t zero_point_B) {
    megdnn_assert(N == 1);
    gemv_naive_n(
            A, B, C, M, N, K, Astride, Bstride, Cstride, zero_point_A,
            zero_point_B);
}
#endif