static inline void quantize_block_f32_q8_0_flat(
float * restrict x,
uint8_t * restrict y_quants,
__fp16 * restrict y_scales,
uint32_t block_idx
) {
HVX_Vector * vx = (HVX_Vector *) x;
HVX_Vector zero = Q6_V_vzero();
HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));
HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));
HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero);
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero);
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
* (HVX_Vector *) (y_quants + block_idx * 128) = vx_i8;
HVX_VectorPair vp1 = Q6_W_vshuff_VVR(vd23_hf, vd01_hf, -2);
HVX_VectorPair vp2 = Q6_W_vshuff_VVR(Q6_V_hi_W(vp1), Q6_V_lo_W(vp1), -2);
HVX_Vector v_scales = Q6_V_lo_W(vp2);
hvx_vec_store_u(y_scales + block_idx * 4, 8, v_scales);
}
static inline void quantize_block_f32_q8_1_flat(
float * restrict x,
uint8_t * restrict y_quants,
__fp16 * restrict y_scales,
uint32_t block_idx
) {
HVX_Vector * vx = (HVX_Vector *) x;
HVX_Vector zero = Q6_V_vzero();
HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));
HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));
HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero);
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero);
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
const HVX_Vector ones = Q6_Vb_vsplat_R(1);
HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones);
v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4));
v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8));
v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16));
* (HVX_Vector *) (y_quants + block_idx * 128) = vx_i8;
HVX_VectorPair vp1 = Q6_W_vshuff_VVR(vd23_hf, vd01_hf, -2);
HVX_VectorPair vp2 = Q6_W_vshuff_VVR(Q6_V_hi_W(vp1), Q6_V_lo_W(vp1), -2);
HVX_Vector v_scales = Q6_V_lo_W(vp2);
HVX_VectorPair v_deal1 = Q6_W_vdeal_VVR(v_sums, v_sums, -4);
HVX_Vector v_even1 = Q6_V_lo_W(v_deal1);
HVX_VectorPair v_deal2 = Q6_W_vdeal_VVR(v_even1, v_even1, -4);
HVX_Vector v_even2 = Q6_V_lo_W(v_deal2);
HVX_VectorPair v_deal3 = Q6_W_vdeal_VVR(v_even2, v_even2, -4);
HVX_Vector v_sums_shuffled = Q6_V_lo_W(v_deal3);
HVX_Vector v_sums_sf = Q6_Vsf_equals_Vw(v_sums_shuffled);
HVX_Vector v_sums_hf = hvx_vec_f32_to_f16(v_sums_sf, Q6_V_vzero());
HVX_Vector v_prod = hvx_vec_mul_f16_f16(v_scales, v_sums_hf);
HVX_VectorPair vp_scales = Q6_W_vshuff_VVR(v_prod, v_scales, -2);
HVX_Vector v_final = Q6_V_lo_W(vp_scales);
hvx_vec_store_u(y_scales + block_idx * 8, 16, v_final);
}
static inline void quantize_row_f32_q8_0_flat(float * restrict x, uint8_t * restrict y, uint32_t k) {
assert(k % 32 == 0);
const uint32_t quants_size = hex_round_up(k, 128);
uint8_t * restrict y_quants = y;
__fp16 * restrict y_scales = (__fp16 *) (y + quants_size);
const uint32_t nb = (k + 127) / 128;
for (uint32_t i = 0; i < nb; i++) {
quantize_block_f32_q8_0_flat(x + i * 128, y_quants, y_scales, i);
}
}
static inline void quantize_row_f32_q8_1_flat(float * restrict x, uint8_t * restrict y, uint32_t k) {
assert(k % 32 == 0);
const uint32_t quants_size = hex_round_up(k, 128);
uint8_t * restrict y_quants = y;
__fp16 * restrict y_scales = (__fp16 *) (y + quants_size);
const uint32_t nb = (k + 127) / 128;
for (uint32_t i = 0; i < nb; i++) {
quantize_block_f32_q8_1_flat(x + i * 128, y_quants, y_scales, i);
}
}
static inline void quantize_f32_q8_0_flat_kernel(
const uint8_t * restrict src_data,
uint8_t * restrict dst_data,
uint8_t * restrict tmp_data,
uint32_t ne0,
uint32_t nrows,
size_t src_row_size,
size_t dst_row_size
) {
const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float));
hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float));
for (uint32_t i = 0; i < nrows; ++i) {
hex_l2fetch(src_data, src_row_size, src_row_size, 2);
hvx_copy_f32_aa(tmp_data, src_data, ne0);
quantize_row_f32_q8_0_flat((float *) tmp_data, dst_data, ne0);
dst_data += dst_row_size;
src_data += src_row_size;
}
}
static inline void quantize_f32_q8_1_flat_kernel(
const uint8_t * restrict src_data,
uint8_t * restrict dst_data,
uint8_t * restrict tmp_data,
uint32_t ne0,
uint32_t nrows,
size_t src_row_size,
size_t dst_row_size
) {
const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float));
hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float));
for (uint32_t i = 0; i < nrows; ++i) {
hex_l2fetch(src_data, src_row_size, src_row_size, 2);
hvx_copy_f32_aa(tmp_data, src_data, ne0);
quantize_row_f32_q8_1_flat((float *) tmp_data, dst_data, ne0);
dst_data += dst_row_size;
src_data += src_row_size;
}
}
static inline void quantize_f32_f32_flat_kernel(
const uint8_t * restrict src_data,
uint8_t * restrict dst_data,
uint8_t * restrict tmp_data,
uint32_t ne0,
uint32_t nrows,
size_t src_stride,
size_t dst_stride
) {
(void) tmp_data;
const size_t src_row_size = ne0 * sizeof(float);
for (uint32_t i = 0; i < nrows; ++i) {
hex_l2fetch(src_data, src_row_size, src_stride, 2);
hvx_copy_f32_au(dst_data, src_data, ne0);
dst_data += dst_stride;
src_data += src_stride;
}
}
static inline void quantize_f32_f16_flat_kernel(
const uint8_t * restrict src_data,
uint8_t * restrict dst_data,
uint8_t * restrict tmp_data,
uint32_t ne0,
uint32_t nrows,
size_t src_stride,
size_t dst_stride
) {
(void) tmp_data;
const size_t src_row_size = ne0 * sizeof(float);
for (uint32_t i = 0; i < nrows; ++i) {
hex_l2fetch(src_data, src_row_size, src_stride, 2);
hvx_copy_f16_f32_au(dst_data, src_data, ne0);
dst_data += dst_stride;
src_data += src_stride;
}
}
static inline void quantize_f16_f16_flat_kernel(
const uint8_t * restrict src_data,
uint8_t * restrict dst_data,
uint8_t * restrict tmp_data,
uint32_t ne0,
uint32_t nrows,
size_t src_stride,
size_t dst_stride
) {
(void) tmp_data;
const size_t src_row_size = ne0 * sizeof(float);
for (uint32_t i = 0; i < nrows; ++i) {
hex_l2fetch(src_data, src_row_size, src_stride, 2);
hvx_copy_f16_au(dst_data, src_data, ne0);
dst_data += dst_stride;
src_data += src_stride;
}
}
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
HVX_Vector v_sum_float = Q6_V_vzero();
HVX_Vector i8 = Q6_Vb_vsplat_R(8);
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl;
const uint32_t quants_size = hex_round_up(n, 128);
const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size);
uint32_t n_k_tiles = n / 32;
for (uint32_t kt = 0; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640);
uint32_t block_idx = kt / 4;
uint32_t sub_idx = kt % 4;
HVX_Vector vx_i8 = * (const HVX_Vector *) (y_q + block_idx * 128);
HVX_Vector v_act_raw = Q6_V_vror_VR(vx_i8, sub_idx * 32);
HVX_Vector v_act_rep[8];
v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl);
v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl);
v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl);
v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl);
v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl);
v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl);
v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl);
v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl);
HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act_rep, i8);
HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum);
HVX_Vector v_scale_w = vptr[4];
__fp16 scale_a_val = y_scales[kt];
HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val));
HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a);
HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb);
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
HVX_Vector v_sum_float_c0 = Q6_V_vzero();
HVX_Vector v_sum_float_c1 = Q6_V_vzero();
HVX_Vector i8 = Q6_Vb_vsplat_R(8);
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl;
const uint32_t quants_size = hex_round_up(n, 128);
const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size);
const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size);
uint32_t n_k_tiles = n / 32;
for (uint32_t kt = 0; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640);
uint32_t block_idx = kt / 4;
uint32_t sub_idx = kt % 4;
HVX_Vector vx0_i8 = * (const HVX_Vector *) (y0_q + block_idx * 128);
HVX_Vector vx1_i8 = * (const HVX_Vector *) (y1_q + block_idx * 128);
HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0_i8, sub_idx * 32);
HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1_i8, sub_idx * 32);
HVX_Vector v_act0_rep[8];
v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl);
v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl);
v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl);
v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl);
v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl);
v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl);
v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl);
v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl);
HVX_Vector v_act1_rep[8];
v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl);
v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl);
v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl);
v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl);
v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl);
v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl);
v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl);
v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl);
HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0_rep, v_act1_rep, i8);
HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums);
HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums);
HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0);
HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1);
HVX_Vector v_scale_w = vptr[4];
__fp16 scale_a0_val = y0_scales[kt];
__fp16 scale_a1_val = y1_scales[kt];
HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val));
HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val));
HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a0);
HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a1);
HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0);
HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1);
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0);
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
HVX_Vector v_sum_float = Q6_V_vzero();
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl;
const uint32_t quants_size = hex_round_up(n, 128);
const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size);
uint32_t n_k_tiles = n / 32;
for (uint32_t kt = 0; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640);
uint32_t block_idx = kt / 4;
uint32_t sub_idx = kt % 4;
HVX_Vector vx_i8 = * (const HVX_Vector *) (y_q + block_idx * 128);
HVX_Vector v_act_raw = Q6_V_vror_VR(vx_i8, sub_idx * 32);
HVX_Vector v_act_rep[8];
v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl);
v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl);
v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl);
v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl);
v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl);
v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl);
v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl);
v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl);
HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act_rep, Q6_V_vzero());
HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum);
HVX_Vector v_scale_offset = vptr[4];
HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2);
HVX_Vector v_scale = Q6_V_lo_W(p_deal);
HVX_Vector v_offset = Q6_V_hi_W(p_deal);
__fp16 scale_a_val = y_scales[kt * 2 + 0];
__fp16 sum_a_val = y_scales[kt * 2 + 1];
HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val));
HVX_Vector v_sum_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&sum_a_val));
HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a);
HVX_Vector v_offset_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a);
HVX_Vector v_scaled_dot = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb);
HVX_Vector v_sum_scaled = hvx_vec_add_f32_f32(v_scaled_dot, v_offset_comb);
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
HVX_Vector v_sum_float_c0 = Q6_V_vzero();
HVX_Vector v_sum_float_c1 = Q6_V_vzero();
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl;
const uint32_t quants_size = hex_round_up(n, 128);
const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size);
const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size);
uint32_t n_k_tiles = n / 32;
for (uint32_t kt = 0; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640);
uint32_t block_idx = kt / 4;
uint32_t sub_idx = kt % 4;
HVX_Vector vx0_i8 = * (const HVX_Vector *) (y0_q + block_idx * 128);
HVX_Vector vx1_i8 = * (const HVX_Vector *) (y1_q + block_idx * 128);
HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0_i8, sub_idx * 32);
HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1_i8, sub_idx * 32);
HVX_Vector v_act0_rep[8];
v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl);
v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl);
v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl);
v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl);
v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl);
v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl);
v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl);
v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl);
HVX_Vector v_act1_rep[8];
v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl);
v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl);
v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl);
v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl);
v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl);
v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl);
v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl);
v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl);
HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0_rep, v_act1_rep, Q6_V_vzero());
HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums);
HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums);
HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0);
HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1);
HVX_Vector v_scale_offset = vptr[4];
HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2);
HVX_Vector v_scale = Q6_V_lo_W(p_deal);
HVX_Vector v_offset = Q6_V_hi_W(p_deal);
__fp16 scale_a0_val = y0_scales[kt * 2 + 0];
__fp16 sum_a0_val = y0_scales[kt * 2 + 1];
__fp16 scale_a1_val = y1_scales[kt * 2 + 0];
__fp16 sum_a1_val = y1_scales[kt * 2 + 1];
HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val));
HVX_Vector v_sum_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&sum_a0_val));
HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val));
HVX_Vector v_sum_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&sum_a1_val));
HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a0);
HVX_Vector v_offset_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a0);
HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a1);
HVX_Vector v_offset_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a1);
HVX_Vector v_scaled_dot_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0);
HVX_Vector v_sum_scaled_c0 = hvx_vec_add_f32_f32(v_scaled_dot_c0, v_offset_comb_c0);
HVX_Vector v_scaled_dot_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1);
HVX_Vector v_sum_scaled_c1 = hvx_vec_add_f32_f32(v_scaled_dot_c1, v_offset_comb_c1);
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0);
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
HVX_Vector v_sum_float = Q6_V_vzero();
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl;
const uint32_t quants_size = hex_round_up(n, 128);
const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size);
uint32_t n_k_tiles = n / 32;
for (uint32_t kt = 0; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152);
uint32_t block_idx = kt / 4;
uint32_t sub_idx = kt % 4;
HVX_Vector vx_i8 = * (const HVX_Vector *) (y_q + block_idx * 128);
HVX_Vector v_act_raw = Q6_V_vror_VR(vx_i8, sub_idx * 32);
HVX_Vector v_act_rep[8];
v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl);
v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl);
v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl);
v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl);
v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl);
v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl);
v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl);
v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl);
HVX_Vector v_sum = accum_q8_0_32x1(vptr, v_act_rep);
HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum);
HVX_Vector v_scale_w = vptr[8];
__fp16 scale_a_val = y_scales[kt];
HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val));
HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a);
HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb);
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
HVX_Vector v_sum_float_c0 = Q6_V_vzero();
HVX_Vector v_sum_float_c1 = Q6_V_vzero();
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl;
const uint32_t quants_size = hex_round_up(n, 128);
const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size);
const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size);
uint32_t n_k_tiles = n / 32;
for (uint32_t kt = 0; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152);
uint32_t block_idx = kt / 4;
uint32_t sub_idx = kt % 4;
HVX_Vector vx0_i8 = * (const HVX_Vector *) (y0_q + block_idx * 128);
HVX_Vector vx1_i8 = * (const HVX_Vector *) (y1_q + block_idx * 128);
HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0_i8, sub_idx * 32);
HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1_i8, sub_idx * 32);
HVX_Vector v_act0_rep[8];
v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl);
v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl);
v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl);
v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl);
v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl);
v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl);
v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl);
v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl);
HVX_Vector v_act1_rep[8];
v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl);
v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl);
v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl);
v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl);
v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl);
v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl);
v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl);
v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl);
HVX_VectorPair v_sums = accum_q8_0_32x2(vptr, v_act0_rep, v_act1_rep);
HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums);
HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums);
HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0);
HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1);
HVX_Vector v_scale_w = vptr[8];
__fp16 scale_a0_val = y0_scales[kt];
__fp16 scale_a1_val = y1_scales[kt];
HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val));
HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val));
HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a0);
HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a1);
HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0);
HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1);
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0);
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
HVX_Vector v_sum_float = Q6_V_vzero();
HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl;
const uint32_t quants_size = hex_round_up(n, 128);
const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size);
uint32_t n_k_tiles = n / 32;
for (uint32_t kt = 0; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640);
uint32_t block_idx = kt / 4;
uint32_t sub_idx = kt % 4;
HVX_Vector vx = * (const HVX_Vector *) (y_q + block_idx * 128);
HVX_Vector v_act_raw = Q6_V_vror_VR(vx, sub_idx * 32);
HVX_Vector v_act_rep[8];
v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl);
v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl);
v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl);
v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl);
v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl);
v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl);
v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl);
v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl);
HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act_rep, mask_h4, lut);
HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum);
HVX_Vector v_scale_w = vptr[4];
__fp16 scale_a_val = y_scales[kt];
HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val));
HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a);
HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb);
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
HVX_Vector v_sum_float_c0 = Q6_V_vzero();
HVX_Vector v_sum_float_c1 = Q6_V_vzero();
HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl;
const uint32_t quants_size = hex_round_up(n, 128);
const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size);
const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size);
uint32_t n_k_tiles = n / 32;
for (uint32_t kt = 0; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640);
uint32_t block_idx = kt / 4;
uint32_t sub_idx = kt % 4;
HVX_Vector vx0 = * (const HVX_Vector *) (y0_q + block_idx * 128);
HVX_Vector vx1 = * (const HVX_Vector *) (y1_q + block_idx * 128);
HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0, sub_idx * 32);
HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1, sub_idx * 32);
HVX_Vector v_act0_rep[8];
v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl);
v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl);
v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl);
v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl);
v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl);
v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl);
v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl);
v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl);
HVX_Vector v_act1_rep[8];
v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl);
v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl);
v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl);
v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl);
v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl);
v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl);
v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl);
v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl);
HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0_rep, v_act1_rep, mask_h4, lut);
HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums);
HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums);
HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0);
HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1);
HVX_Vector v_scale_w = vptr[4];
__fp16 scale_a0_val = y0_scales[kt];
__fp16 scale_a1_val = y1_scales[kt];
HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val));
HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val));
HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a0);
HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a1);
HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0);
HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1);
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0);
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
HVX_Vector v_sum_float = Q6_V_vzero();
HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl;
const uint32_t quants_size = hex_round_up(n, 128);
const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size);
uint32_t n_k_tiles = n / 32;
for (uint32_t kt = 0; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640);
uint32_t block_idx = kt / 4;
uint32_t sub_idx = kt % 4;
HVX_Vector vx = * (const HVX_Vector *) (y_q + block_idx * 128);
HVX_Vector v_act_raw = Q6_V_vror_VR(vx, sub_idx * 32);
HVX_Vector v_act_rep[8];
v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl);
v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl);
v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl);
v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl);
v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl);
v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl);
v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl);
v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl);
HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act_rep, mask_h4, lut);
HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum);
HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512);
HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand);
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23);
__fp16 scale_a_val = y_scales[kt];
HVX_Vector v_scale_a_f16 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val));
HVX_VectorPair p_scale_a_f32 = hvx_vec_f16_to_f32(v_scale_a_f16);
HVX_Vector v_scale_a = Q6_V_lo_W(p_scale_a_f32);
HVX_Vector v_scale_comb = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a);
HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb);
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
HVX_Vector v_sum_float_c0 = Q6_V_vzero();
HVX_Vector v_sum_float_c1 = Q6_V_vzero();
HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl;
const uint32_t quants_size = hex_round_up(n, 128);
const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size);
const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size);
uint32_t n_k_tiles = n / 32;
for (uint32_t kt = 0; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640);
uint32_t block_idx = kt / 4;
uint32_t sub_idx = kt % 4;
HVX_Vector vx0 = * (const HVX_Vector *) (y0_q + block_idx * 128);
HVX_Vector vx1 = * (const HVX_Vector *) (y1_q + block_idx * 128);
HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0, sub_idx * 32);
HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1, sub_idx * 32);
HVX_Vector v_act0_rep[8];
v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl);
v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl);
v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl);
v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl);
v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl);
v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl);
v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl);
v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl);
HVX_Vector v_act1_rep[8];
v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl);
v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl);
v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl);
v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl);
v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl);
v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl);
v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl);
v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl);
HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0_rep, v_act1_rep, mask_h4, lut);
HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums);
HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums);
HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0);
HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1);
HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512);
HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand);
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23);
__fp16 scale_a0_val = y0_scales[kt];
__fp16 scale_a1_val = y1_scales[kt];
HVX_Vector v_scale_a0_f16 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val));
HVX_Vector v_scale_a1_f16 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val));
HVX_VectorPair p_scale_a0_f32 = hvx_vec_f16_to_f32(v_scale_a0_f16);
HVX_VectorPair p_scale_a1_f32 = hvx_vec_f16_to_f32(v_scale_a1_f16);
HVX_Vector v_scale_a0 = Q6_V_lo_W(p_scale_a0_f32);
HVX_Vector v_scale_a1 = Q6_V_lo_W(p_scale_a1_f32);
HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a0);
HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a1);
HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0);
HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1);
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0);
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
static inline void vec_dot_f32_f32_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
uint32_t nvec = n / VLEN_FP32; uint32_t nloe = n % VLEN_FP32;
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]);
rsum = HVX_OP_ADD_F32(rsum, prod);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf);
rsum = HVX_OP_ADD_F32(rsum, prod);
}
*s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum));
}
static inline void vec_dot_f32_f32_aa_2x1(const uint32_t n, float * restrict s0,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
uint32_t nvec = n / VLEN_FP32;
uint32_t nloe = n % VLEN_FP32;
HVX_Vector rsum0 = Q6_V_vzero();
HVX_Vector rsum1 = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y_sf = y[i];
HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf);
HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf);
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf);
HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf);
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
}
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
hvx_vec_store_u(s0, 8, rsum);
}
static inline void vec_dot_f32_f32_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
uint32_t nvec = n / VLEN_FP32;
uint32_t nloe = n % VLEN_FP32;
HVX_Vector r0_c0_sum = Q6_V_vzero();
HVX_Vector r0_c1_sum = Q6_V_vzero();
HVX_Vector r1_c0_sum = Q6_V_vzero();
HVX_Vector r1_c1_sum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector r0_sf = x0[i];
HVX_Vector r1_sf = x1[i];
HVX_Vector c0_sf = y0[i];
HVX_Vector c1_sf = y1[i];
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]);
HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]);
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
}
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(s0, 8, r0_r1_c0_sum);
hvx_vec_store_u(s1, 8, r0_r1_c1_sum);
}
static inline void vec_dot_f32_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
uint32_t nvec = n / VLEN_FP32; uint32_t nloe = n % VLEN_FP32;
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector x_sf = vx[i];
HVX_Vector y_sf = vy[i];
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
}
if (nloe) {
HVX_Vector x_sf = vx[i];
HVX_Vector y_sf = vy[i];
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
x_sf = Q6_V_vand_QV(bmask, x_sf);
y_sf = Q6_V_vand_QV(bmask, y_sf);
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
}
rsum = hvx_vec_reduce_sum_f32(rsum);
hvx_vec_store_u(&s[0], 4, rsum);
}
#undef HVX_OP_ADD_F32
#undef HVX_OP_MUL_F32
static inline void vec_dot_f16_f16_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16;
HVX_VectorPair rsum_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}
HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum));
}
static inline void vec_dot_f16_f16_aa_2x1(const uint32_t n, float * restrict s0,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
HVX_VectorPair rsum0_p = Q6_W_vzero();
HVX_VectorPair rsum1_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = y[i];
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
}
HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
hvx_vec_store_u(s0, 8, rsum);
}
static inline void vec_dot_f16_f16_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();
HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();
HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();
HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector r0_hf = x0[i];
HVX_Vector r1_hf = x1[i];
HVX_Vector c0_hf = y0[i];
HVX_Vector c1_hf = y1[i];
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
}
HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p)));
HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p)));
HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p)));
HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p)));
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); }
static inline void vec_dot_f16_f16_uu_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_UVector * restrict x = (const HVX_UVector *) vx;
const HVX_UVector * restrict y = (const HVX_UVector *) vy;
uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16;
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
static inline void vec_dot_f16_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16;
const HVX_Vector zero = Q6_V_vzero();
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
HVX_Vector x_hf = vx[i];
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x_hf = Q6_V_vand_QV(bmask, x_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
static inline void hvx_tensor_add_f32_grid(
const struct htp_tensor * restrict dst,
const struct htp_tensor * restrict src2,
uint32_t start_row,
uint32_t end_row,
uint32_t start_col,
uint32_t end_col,
const struct fastdiv_values * div_ne11_12,
const struct fastdiv_values * div_ne11
) {
if (start_row >= end_row || start_col >= end_col) return;
const uint32_t nb1 = dst->nb[1];
const uint32_t ne11 = dst->ne[1];
const uint32_t ne12 = dst->ne[2];
const uint32_t ne11_12 = ne11 * ne12;
const bool is_broadcast1 = (src2->ne[1] == 1);
const bool is_broadcast2 = (src2->ne[2] == 1);
const bool is_broadcast3 = (src2->ne[3] == 1);
for (uint32_t r = start_row; r < end_row; r++) {
float * dst_row = (float *) ((uint8_t *) dst->data + r * nb1);
uint32_t i13 = fastdiv(r, div_ne11_12);
uint32_t i12 = fastdiv(r - i13 * ne11_12, div_ne11);
uint32_t i11 = r - i13 * ne11_12 - i12 * ne11;
uint32_t i23 = is_broadcast3 ? 0 : i13;
uint32_t i22 = is_broadcast2 ? 0 : i12;
uint32_t i21 = is_broadcast1 ? 0 : i11;
const float * src2_row = (const float *) ((const uint8_t *) src2->data +
i21 * src2->nb[1] + i22 * src2->nb[2] + i23 * src2->nb[3]);
float * dst_ptr = &dst_row[start_col];
const float * src2_ptr = &src2_row[start_col];
int remaining = end_col - start_col;
while (remaining >= 32) {
HVX_Vector v_out = hvx_vmemu(dst_ptr);
HVX_Vector v_z = hvx_vmemu(src2_ptr);
hvx_vmemu(dst_ptr) = hvx_vec_add_f32_f32(v_out, v_z);
dst_ptr += 32;
src2_ptr += 32;
remaining -= 32;
}
if (remaining > 0) {
HVX_Vector v_out = hvx_vmemu(dst_ptr);
HVX_Vector v_z = hvx_vmemu(src2_ptr);
hvx_vec_store_u(dst_ptr, remaining * sizeof(float), hvx_vec_add_f32_f32(v_out, v_z));
}
}
}