#pragma once
#include "src/common/unroll_macro.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/*
 * MATRIX_MUL4x4_fp16(sum, a, b): 4x4 half-precision matrix product.
 *
 * The arguments are name prefixes; the macro pastes the suffixes 0..3 onto
 * them, so it reads a0..a3 and b0..b3 and writes sum0..sum3. All of them are
 * float16x4_t rows (the operand types vmul_lane_f16 takes), and sum0..sum3
 * must already be declared by the caller.
 *
 * Row i of the result is a_i[0]*b0 + a_i[1]*b1 + a_i[2]*b2 + a_i[3]*b3,
 * i.e. SUM = A * B with the matrices stored as rows. Each product is rounded
 * before it is added (vmul_lane_f16 then vadd_f16, not a fused multiply-add),
 * and the terms are accumulated in k = 0..3 order.
 *
 * The statements are wrapped in do { } while (0) so that all of them are
 * guarded when the macro is the body of an unbraced if/for/while (previously
 * only the first assignment was). The trailing ';' is kept so existing call
 * sites written without one still compile.
 */
#define MATRIX_MUL4x4_fp16(sum, a, b) \
    do { \
        sum##0 = vmul_lane_f16(b##0, a##0, 0); \
        sum##1 = vmul_lane_f16(b##0, a##1, 0); \
        sum##2 = vmul_lane_f16(b##0, a##2, 0); \
        sum##3 = vmul_lane_f16(b##0, a##3, 0); \
        sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##1, a##0, 1)); \
        sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##1, a##1, 1)); \
        sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##1, a##2, 1)); \
        sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##1, a##3, 1)); \
        sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##2, a##0, 2)); \
        sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##2, a##1, 2)); \
        sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##2, a##2, 2)); \
        sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##2, a##3, 2)); \
        sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##3, a##0, 3)); \
        sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##3, a##1, 3)); \
        sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##3, a##2, 3)); \
        sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##3, a##3, 3)); \
    } while (0);
/*
 * CONCAT(a, id): paste two tokens into a single identifier, e.g.
 * CONCAT(src, 3) -> src3. The TRANSPOSE_* macros in this header use it to
 * name the numbered row variables given by their `a` / `ret` prefixes.
 * The operands of ## are pasted as written, without being macro-expanded
 * first.
 * NOTE(review): the name is very generic; if another header defines CONCAT
 * differently, this redefinition will conflict -- consider a project prefix.
 */
#define CONCAT(a, id) a##id
#if MEGDNN_AARCH64
/*
 * TRANSPOSE_4x4(a, ret) -- AArch64 variant.
 * Transposes a 4x4 fp16 tile: a0..a3 are the input rows and ret0..ret3
 * receive the columns (ret_j[i] = a_i[j]). Rows are accessed through their
 * `.value` member, a float16x4_t.
 * Method: interleave the 16-bit lanes of row pairs (0,1) and (2,3), then view
 * each adjacent lane pair as one 32-bit element and interleave again.
 * Every input row is read before the first output row is written.
 */
#define TRANSPOSE_4x4(a, ret) \
    do { \
        /* 16-bit zips: lanes alternate r0/r1 and r2/r3 */ \
        int32x2_t t44_p01_lo = vreinterpret_s32_f16( \
                vzip1_f16(CONCAT(a, 0).value, CONCAT(a, 1).value)); \
        int32x2_t t44_p01_hi = vreinterpret_s32_f16( \
                vzip2_f16(CONCAT(a, 0).value, CONCAT(a, 1).value)); \
        int32x2_t t44_p23_lo = vreinterpret_s32_f16( \
                vzip1_f16(CONCAT(a, 2).value, CONCAT(a, 3).value)); \
        int32x2_t t44_p23_hi = vreinterpret_s32_f16( \
                vzip2_f16(CONCAT(a, 2).value, CONCAT(a, 3).value)); \
        /* 32-bit zips join the (r0,r1) and (r2,r3) pairs into columns */ \
        CONCAT(ret, 0).value = \
                vreinterpret_f16_s32(vzip1_s32(t44_p01_lo, t44_p23_lo)); \
        CONCAT(ret, 1).value = \
                vreinterpret_f16_s32(vzip2_s32(t44_p01_lo, t44_p23_lo)); \
        CONCAT(ret, 2).value = \
                vreinterpret_f16_s32(vzip1_s32(t44_p01_hi, t44_p23_hi)); \
        CONCAT(ret, 3).value = \
                vreinterpret_f16_s32(vzip2_s32(t44_p01_hi, t44_p23_hi)); \
    } while (0);
/*
 * TRANSPOSE_4x8(a, ret) -- AArch64 variant.
 * Transposes a 4-row x 8-column fp16 tile: rows a0..a3 (`.value` is a
 * float16x8_t) become columns ret0..ret7 (`.value` is a float16x4_t), with
 * ret_j[i] = a_i[j].
 * After the 16-bit and 32-bit zips, each 128-bit intermediate holds two
 * complete columns: the lower one in its low half, the next in its high half.
 */
#define TRANSPOSE_4x8(a, ret) \
    do { \
        int32x4_t t48_p01_lo = vreinterpretq_s32_f16( \
                vzip1q_f16(CONCAT(a, 0).value, CONCAT(a, 1).value)); \
        int32x4_t t48_p01_hi = vreinterpretq_s32_f16( \
                vzip2q_f16(CONCAT(a, 0).value, CONCAT(a, 1).value)); \
        int32x4_t t48_p23_lo = vreinterpretq_s32_f16( \
                vzip1q_f16(CONCAT(a, 2).value, CONCAT(a, 3).value)); \
        int32x4_t t48_p23_hi = vreinterpretq_s32_f16( \
                vzip2q_f16(CONCAT(a, 2).value, CONCAT(a, 3).value)); \
        float16x8_t t48_c01 = \
                vreinterpretq_f16_s32(vzip1q_s32(t48_p01_lo, t48_p23_lo)); \
        float16x8_t t48_c23 = \
                vreinterpretq_f16_s32(vzip2q_s32(t48_p01_lo, t48_p23_lo)); \
        float16x8_t t48_c45 = \
                vreinterpretq_f16_s32(vzip1q_s32(t48_p01_hi, t48_p23_hi)); \
        float16x8_t t48_c67 = \
                vreinterpretq_f16_s32(vzip2q_s32(t48_p01_hi, t48_p23_hi)); \
        /* split each column pair into its two 64-bit halves */ \
        CONCAT(ret, 0).value = vget_low_f16(t48_c01); \
        CONCAT(ret, 1).value = vget_high_f16(t48_c01); \
        CONCAT(ret, 2).value = vget_low_f16(t48_c23); \
        CONCAT(ret, 3).value = vget_high_f16(t48_c23); \
        CONCAT(ret, 4).value = vget_low_f16(t48_c45); \
        CONCAT(ret, 5).value = vget_high_f16(t48_c45); \
        CONCAT(ret, 6).value = vget_low_f16(t48_c67); \
        CONCAT(ret, 7).value = vget_high_f16(t48_c67); \
    } while (0);
/*
 * TRANSPOSE_8x4(a, ret) -- AArch64 variant.
 * Transposes an 8-row x 4-column fp16 tile: rows a0..a7 (`.value` is a
 * float16x4_t) become columns ret0..ret3 (`.value` is a float16x8_t), with
 * ret_j[i] = a_i[j].
 * Rows 0..3 and rows 4..7 are transposed as two 4x4 halves; each output
 * column joins the rows 0..3 half (low) with the rows 4..7 half (high).
 * Every input row is read before the first output row is written.
 */
#define TRANSPOSE_8x4(a, ret) \
    do { \
        int32x2_t t84_p01_lo = vreinterpret_s32_f16( \
                vzip1_f16(CONCAT(a, 0).value, CONCAT(a, 1).value)); \
        int32x2_t t84_p01_hi = vreinterpret_s32_f16( \
                vzip2_f16(CONCAT(a, 0).value, CONCAT(a, 1).value)); \
        int32x2_t t84_p23_lo = vreinterpret_s32_f16( \
                vzip1_f16(CONCAT(a, 2).value, CONCAT(a, 3).value)); \
        int32x2_t t84_p23_hi = vreinterpret_s32_f16( \
                vzip2_f16(CONCAT(a, 2).value, CONCAT(a, 3).value)); \
        int32x2_t t84_p45_lo = vreinterpret_s32_f16( \
                vzip1_f16(CONCAT(a, 4).value, CONCAT(a, 5).value)); \
        int32x2_t t84_p45_hi = vreinterpret_s32_f16( \
                vzip2_f16(CONCAT(a, 4).value, CONCAT(a, 5).value)); \
        int32x2_t t84_p67_lo = vreinterpret_s32_f16( \
                vzip1_f16(CONCAT(a, 6).value, CONCAT(a, 7).value)); \
        int32x2_t t84_p67_hi = vreinterpret_s32_f16( \
                vzip2_f16(CONCAT(a, 6).value, CONCAT(a, 7).value)); \
        CONCAT(ret, 0).value = vcombine_f16( \
                vreinterpret_f16_s32(vzip1_s32(t84_p01_lo, t84_p23_lo)), \
                vreinterpret_f16_s32(vzip1_s32(t84_p45_lo, t84_p67_lo))); \
        CONCAT(ret, 1).value = vcombine_f16( \
                vreinterpret_f16_s32(vzip2_s32(t84_p01_lo, t84_p23_lo)), \
                vreinterpret_f16_s32(vzip2_s32(t84_p45_lo, t84_p67_lo))); \
        CONCAT(ret, 2).value = vcombine_f16( \
                vreinterpret_f16_s32(vzip1_s32(t84_p01_hi, t84_p23_hi)), \
                vreinterpret_f16_s32(vzip1_s32(t84_p45_hi, t84_p67_hi))); \
        CONCAT(ret, 3).value = vcombine_f16( \
                vreinterpret_f16_s32(vzip2_s32(t84_p01_hi, t84_p23_hi)), \
                vreinterpret_f16_s32(vzip2_s32(t84_p45_hi, t84_p67_hi))); \
    } while (0);
/*
 * TRANSPOSE_8x8(a, ret) -- AArch64 variant.
 * Transposes an 8x8 fp16 tile: rows a0..a7 become columns ret0..ret7
 * (ret_j[i] = a_i[j]); every `.value` is a float16x8_t.
 * Three interleave stages at 16, 32 and 64-bit granularity. After the 32-bit
 * stage each 64-bit lane holds one column over four rows ("top" = rows 0..3,
 * "bot" = rows 4..7); the 64-bit stage stacks top over bot.
 * Every input row is read before the first output row is written.
 */
#define TRANSPOSE_8x8(a, ret) \
    do { \
        int32x4_t t88_p01_lo = vreinterpretq_s32_f16( \
                vzip1q_f16(CONCAT(a, 0).value, CONCAT(a, 1).value)); \
        int32x4_t t88_p01_hi = vreinterpretq_s32_f16( \
                vzip2q_f16(CONCAT(a, 0).value, CONCAT(a, 1).value)); \
        int32x4_t t88_p23_lo = vreinterpretq_s32_f16( \
                vzip1q_f16(CONCAT(a, 2).value, CONCAT(a, 3).value)); \
        int32x4_t t88_p23_hi = vreinterpretq_s32_f16( \
                vzip2q_f16(CONCAT(a, 2).value, CONCAT(a, 3).value)); \
        int32x4_t t88_p45_lo = vreinterpretq_s32_f16( \
                vzip1q_f16(CONCAT(a, 4).value, CONCAT(a, 5).value)); \
        int32x4_t t88_p45_hi = vreinterpretq_s32_f16( \
                vzip2q_f16(CONCAT(a, 4).value, CONCAT(a, 5).value)); \
        int32x4_t t88_p67_lo = vreinterpretq_s32_f16( \
                vzip1q_f16(CONCAT(a, 6).value, CONCAT(a, 7).value)); \
        int32x4_t t88_p67_hi = vreinterpretq_s32_f16( \
                vzip2q_f16(CONCAT(a, 6).value, CONCAT(a, 7).value)); \
        int64x2_t t88_top_c01 = \
                vreinterpretq_s64_s32(vzip1q_s32(t88_p01_lo, t88_p23_lo)); \
        int64x2_t t88_top_c23 = \
                vreinterpretq_s64_s32(vzip2q_s32(t88_p01_lo, t88_p23_lo)); \
        int64x2_t t88_top_c45 = \
                vreinterpretq_s64_s32(vzip1q_s32(t88_p01_hi, t88_p23_hi)); \
        int64x2_t t88_top_c67 = \
                vreinterpretq_s64_s32(vzip2q_s32(t88_p01_hi, t88_p23_hi)); \
        int64x2_t t88_bot_c01 = \
                vreinterpretq_s64_s32(vzip1q_s32(t88_p45_lo, t88_p67_lo)); \
        int64x2_t t88_bot_c23 = \
                vreinterpretq_s64_s32(vzip2q_s32(t88_p45_lo, t88_p67_lo)); \
        int64x2_t t88_bot_c45 = \
                vreinterpretq_s64_s32(vzip1q_s32(t88_p45_hi, t88_p67_hi)); \
        int64x2_t t88_bot_c67 = \
                vreinterpretq_s64_s32(vzip2q_s32(t88_p45_hi, t88_p67_hi)); \
        CONCAT(ret, 0).value = \
                vreinterpretq_f16_s64(vzip1q_s64(t88_top_c01, t88_bot_c01)); \
        CONCAT(ret, 1).value = \
                vreinterpretq_f16_s64(vzip2q_s64(t88_top_c01, t88_bot_c01)); \
        CONCAT(ret, 2).value = \
                vreinterpretq_f16_s64(vzip1q_s64(t88_top_c23, t88_bot_c23)); \
        CONCAT(ret, 3).value = \
                vreinterpretq_f16_s64(vzip2q_s64(t88_top_c23, t88_bot_c23)); \
        CONCAT(ret, 4).value = \
                vreinterpretq_f16_s64(vzip1q_s64(t88_top_c45, t88_bot_c45)); \
        CONCAT(ret, 5).value = \
                vreinterpretq_f16_s64(vzip2q_s64(t88_top_c45, t88_bot_c45)); \
        CONCAT(ret, 6).value = \
                vreinterpretq_f16_s64(vzip1q_s64(t88_top_c67, t88_bot_c67)); \
        CONCAT(ret, 7).value = \
                vreinterpretq_f16_s64(vzip2q_s64(t88_top_c67, t88_bot_c67)); \
    } while (0);
#else
/*
 * TRANSPOSE_4x4(a, ret) -- ARMv7 variant.
 * Transposes a 4x4 fp16 tile: rows a0..a3 become columns ret0..ret3
 * (ret_j[i] = a_i[j]); every `.value` is a float16x4_t.
 * vzip_f16 / vzip_s32 return both interleaved halves at once (.val[0] is
 * the low interleave, .val[1] the high one).
 * Every input row is read before the first output row is written.
 */
#define TRANSPOSE_4x4(a, ret) \
    do { \
        float16x4x2_t t44_r01 = \
                vzip_f16(CONCAT(a, 0).value, CONCAT(a, 1).value); \
        float16x4x2_t t44_r23 = \
                vzip_f16(CONCAT(a, 2).value, CONCAT(a, 3).value); \
        /* .val[0] of t44_c01 is column 0, .val[1] column 1 (same for 2,3) */ \
        int32x2x2_t t44_c01 = vzip_s32(vreinterpret_s32_f16(t44_r01.val[0]), \
                                       vreinterpret_s32_f16(t44_r23.val[0])); \
        int32x2x2_t t44_c23 = vzip_s32(vreinterpret_s32_f16(t44_r01.val[1]), \
                                       vreinterpret_s32_f16(t44_r23.val[1])); \
        CONCAT(ret, 0).value = vreinterpret_f16_s32(t44_c01.val[0]); \
        CONCAT(ret, 1).value = vreinterpret_f16_s32(t44_c01.val[1]); \
        CONCAT(ret, 2).value = vreinterpret_f16_s32(t44_c23.val[0]); \
        CONCAT(ret, 3).value = vreinterpret_f16_s32(t44_c23.val[1]); \
    } while (0);
/*
 * TRANSPOSE_4x8(a, ret) -- ARMv7 variant (uses vzipq_* pair results instead
 * of the vzip1q/vzip2q forms of the AArch64 path).
 * Transposes a 4-row x 8-column fp16 tile: rows a0..a3 (`.value` is a
 * float16x8_t) become columns ret0..ret7 (`.value` is a float16x4_t), with
 * ret_j[i] = a_i[j].
 *
 * The 32-bit zip results are int32x4_t, so each 64-bit half is extracted
 * with vget_low_s32 / vget_high_s32 and only then reinterpreted as
 * float16x4_t. (Applying vget_low_f16 / vget_high_f16 to the int32x4_t
 * values, as this macro previously did, is a NEON vector-type mismatch.)
 */
#define TRANSPOSE_4x8(a, ret) \
    do { \
        auto b0_01 = vzipq_f16(CONCAT(a, 0).value, CONCAT(a, 1).value); \
        auto b1_01 = vzipq_f16(CONCAT(a, 2).value, CONCAT(a, 3).value); \
        auto s32b00 = vreinterpretq_s32_f16(b0_01.val[0]); \
        auto s32b01 = vreinterpretq_s32_f16(b0_01.val[1]); \
        auto s32b10 = vreinterpretq_s32_f16(b1_01.val[0]); \
        auto s32b11 = vreinterpretq_s32_f16(b1_01.val[1]); \
        /* each .val[k] holds two columns: low half, then high half */ \
        auto s32b00b10 = vzipq_s32(s32b00, s32b10); \
        auto s32b01b11 = vzipq_s32(s32b01, s32b11); \
        CONCAT(ret, 0).value = \
                vreinterpret_f16_s32(vget_low_s32(s32b00b10.val[0])); \
        CONCAT(ret, 1).value = \
                vreinterpret_f16_s32(vget_high_s32(s32b00b10.val[0])); \
        CONCAT(ret, 2).value = \
                vreinterpret_f16_s32(vget_low_s32(s32b00b10.val[1])); \
        CONCAT(ret, 3).value = \
                vreinterpret_f16_s32(vget_high_s32(s32b00b10.val[1])); \
        CONCAT(ret, 4).value = \
                vreinterpret_f16_s32(vget_low_s32(s32b01b11.val[0])); \
        CONCAT(ret, 5).value = \
                vreinterpret_f16_s32(vget_high_s32(s32b01b11.val[0])); \
        CONCAT(ret, 6).value = \
                vreinterpret_f16_s32(vget_low_s32(s32b01b11.val[1])); \
        CONCAT(ret, 7).value = \
                vreinterpret_f16_s32(vget_high_s32(s32b01b11.val[1])); \
    } while (0);
/*
 * TRANSPOSE_8x4(a, ret) -- ARMv7 variant.
 * Transposes an 8-row x 4-column fp16 tile: rows a0..a7 (`.value` is a
 * float16x4_t) become columns ret0..ret3 (`.value` is a float16x8_t), with
 * ret_j[i] = a_i[j].
 * Rows 0..3 ("top") and rows 4..7 ("bot") are transposed as two 4x4 halves;
 * each output column places the top half in its low 64 bits.
 * Every input row is read before the first output row is written.
 */
#define TRANSPOSE_8x4(a, ret) \
    do { \
        float16x4x2_t t84_r01 = \
                vzip_f16(CONCAT(a, 0).value, CONCAT(a, 1).value); \
        float16x4x2_t t84_r23 = \
                vzip_f16(CONCAT(a, 2).value, CONCAT(a, 3).value); \
        float16x4x2_t t84_r45 = \
                vzip_f16(CONCAT(a, 4).value, CONCAT(a, 5).value); \
        float16x4x2_t t84_r67 = \
                vzip_f16(CONCAT(a, 6).value, CONCAT(a, 7).value); \
        /* c01: .val[0] = column 0, .val[1] = column 1; c23 likewise */ \
        int32x2x2_t t84_top_c01 = \
                vzip_s32(vreinterpret_s32_f16(t84_r01.val[0]), \
                         vreinterpret_s32_f16(t84_r23.val[0])); \
        int32x2x2_t t84_top_c23 = \
                vzip_s32(vreinterpret_s32_f16(t84_r01.val[1]), \
                         vreinterpret_s32_f16(t84_r23.val[1])); \
        int32x2x2_t t84_bot_c01 = \
                vzip_s32(vreinterpret_s32_f16(t84_r45.val[0]), \
                         vreinterpret_s32_f16(t84_r67.val[0])); \
        int32x2x2_t t84_bot_c23 = \
                vzip_s32(vreinterpret_s32_f16(t84_r45.val[1]), \
                         vreinterpret_s32_f16(t84_r67.val[1])); \
        CONCAT(ret, 0).value = \
                vcombine_f16(vreinterpret_f16_s32(t84_top_c01.val[0]), \
                             vreinterpret_f16_s32(t84_bot_c01.val[0])); \
        CONCAT(ret, 1).value = \
                vcombine_f16(vreinterpret_f16_s32(t84_top_c01.val[1]), \
                             vreinterpret_f16_s32(t84_bot_c01.val[1])); \
        CONCAT(ret, 2).value = \
                vcombine_f16(vreinterpret_f16_s32(t84_top_c23.val[0]), \
                             vreinterpret_f16_s32(t84_bot_c23.val[0])); \
        CONCAT(ret, 3).value = \
                vcombine_f16(vreinterpret_f16_s32(t84_top_c23.val[1]), \
                             vreinterpret_f16_s32(t84_bot_c23.val[1])); \
    } while (0);
/*
 * TRANSPOSE_8x8(a, ret) -- ARMv7 variant.
 * Transposes an 8x8 fp16 tile: rows a0..a7 become columns ret0..ret7
 * (ret_j[i] = a_i[j]); every `.value` is a float16x8_t.
 * Every input row is read before the first output row is written.
 *
 * vzipq_f16 returns both interleaved halves (.val[0] and .val[1]), so each
 * pair of input rows is zipped exactly once here. The previous version issued
 * every vzipq_f16 twice on identical operands and kept one half of each
 * result.
 */
#define TRANSPOSE_8x8(a, ret) \
    do { \
        auto b0_01 = vzipq_f16(CONCAT(a, 0).value, CONCAT(a, 1).value); \
        auto b1_01 = vzipq_f16(CONCAT(a, 2).value, CONCAT(a, 3).value); \
        auto b2_01 = vzipq_f16(CONCAT(a, 4).value, CONCAT(a, 5).value); \
        auto b3_01 = vzipq_f16(CONCAT(a, 6).value, CONCAT(a, 7).value); \
        auto s32b00 = vreinterpretq_s32_f16(b0_01.val[0]); \
        auto s32b01 = vreinterpretq_s32_f16(b0_01.val[1]); \
        auto s32b10 = vreinterpretq_s32_f16(b1_01.val[0]); \
        auto s32b11 = vreinterpretq_s32_f16(b1_01.val[1]); \
        auto s32b20 = vreinterpretq_s32_f16(b2_01.val[0]); \
        auto s32b21 = vreinterpretq_s32_f16(b2_01.val[1]); \
        auto s32b30 = vreinterpretq_s32_f16(b3_01.val[0]); \
        auto s32b31 = vreinterpretq_s32_f16(b3_01.val[1]); \
        /* b00b10 / b20b30: .val[0] = columns 0|1, .val[1] = columns 2|3 */ \
        /* b01b11 / b21b31: .val[0] = columns 4|5, .val[1] = columns 6|7 */ \
        auto s32b00b10 = vzipq_s32(s32b00, s32b10); \
        auto s32b01b11 = vzipq_s32(s32b01, s32b11); \
        auto s32b20b30 = vzipq_s32(s32b20, s32b30); \
        auto s32b21b31 = vzipq_s32(s32b21, s32b31); \
        /* rows 0..3 go in the low half, rows 4..7 in the high half */ \
        CONCAT(ret, 0).value = vreinterpretq_f16_s32(vcombine_s32( \
                vget_low_s32(s32b00b10.val[0]), vget_low_s32(s32b20b30.val[0]))); \
        CONCAT(ret, 1).value = vreinterpretq_f16_s32(vcombine_s32( \
                vget_high_s32(s32b00b10.val[0]), vget_high_s32(s32b20b30.val[0]))); \
        CONCAT(ret, 2).value = vreinterpretq_f16_s32(vcombine_s32( \
                vget_low_s32(s32b00b10.val[1]), vget_low_s32(s32b20b30.val[1]))); \
        CONCAT(ret, 3).value = vreinterpretq_f16_s32(vcombine_s32( \
                vget_high_s32(s32b00b10.val[1]), vget_high_s32(s32b20b30.val[1]))); \
        CONCAT(ret, 4).value = vreinterpretq_f16_s32(vcombine_s32( \
                vget_low_s32(s32b01b11.val[0]), vget_low_s32(s32b21b31.val[0]))); \
        CONCAT(ret, 5).value = vreinterpretq_f16_s32(vcombine_s32( \
                vget_high_s32(s32b01b11.val[0]), vget_high_s32(s32b21b31.val[0]))); \
        CONCAT(ret, 6).value = vreinterpretq_f16_s32(vcombine_s32( \
                vget_low_s32(s32b01b11.val[1]), vget_low_s32(s32b21b31.val[1]))); \
        CONCAT(ret, 7).value = vreinterpretq_f16_s32(vcombine_s32( \
                vget_high_s32(s32b01b11.val[1]), vget_high_s32(s32b21b31.val[1]))); \
    } while (0);
#endif
#endif