#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <HAP_farf.h>
#include <HAP_compute_res.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "hex-dma.h"
#include "worker-pool.h"
#include "hvx-utils.h"
#include "hvx-dump.h"
#include "htp-ctx.h"
#include "htp-ops.h"
#include "hmx-ops.h"
#include "hmx-utils.h"
#include "hmx-queue.h"
#include "hmx-profile.h"
/* Nibble -> fp16 lookup tables for Q6_Wh_vlut16_VbVhR. Each of the 16 values
 * sits at an even index with a 0 filler at the odd index (vlut16 consumes
 * halfword entries); the arrays are padded to 64 entries / one full vector.
 * Q4_0: offset-binary nibble 0..15 maps to -8..7. */
static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
    -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0,
};
/* MXFP4 (E2M1) code points 0..15: magnitudes 0..6 then the negated set. */
static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
    0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0,
};
/* IQ4_NL non-linear 4-bit codebook values. */
static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
    -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0,
    1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0,
};
/* Byte offsets (word lane i -> i*128) used as the base scatter address vector
 * when transposing 32 fp16 pairs into consecutive 128-byte tile rows. */
static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = {
    0*128, 1*128, 2*128, 3*128, 4*128, 5*128, 6*128, 7*128,
    8*128, 9*128, 10*128, 11*128, 12*128, 13*128, 14*128, 15*128,
    16*128, 17*128, 18*128, 19*128, 20*128, 21*128, 22*128, 23*128,
    24*128, 25*128, 26*128, 27*128, 28*128, 29*128, 30*128, 31*128
};
/* x4x2 packed-row metadata sizes (bytes). */
#define HMX_X4X2_SCALES_PER_BLK 8   /* fp16 scales per super-block (Q4_0/Q8_0 layouts) */
#define HMX_X4X2_DBLK_SIZE 16       /* scale block bytes per super-block (Q4_0/IQ4_NL/Q8_0) */
#define HMX_X4X2_MXFP4_EBLK_SIZE 8  /* E8M0 exponent bytes per super-block (MXFP4) */
/* Exchange the pointees of two generic pointer slots. */
static inline void swap_ptr(void **p1, void **p2) {
    void *tmp = *p2;
    *p2 = *p1;
    *p1 = tmp;
}
/* Per-task state for fetching quantized weight rows via DMA.
 * NOTE(review): not referenced in this part of the file; field semantics are
 * assumed from the names — confirm against the code that uses it. */
typedef struct {
    uint8_t *dst;         /* destination buffer base (presumably VTCM) */
    const uint8_t *src;   /* source buffer base (presumably DDR) */
    dma_queue *dma;       /* DMA queue used for the transfers */
    size_t n_rows;        /* rows to fetch */
    /* strides and packed-layout offsets/widths, in bytes */
    size_t src_stride; size_t dst_stride; size_t quant_off; size_t quant_width; size_t scale_off; size_t scale_width; } qweight_fetch_task_state_t;
/* Bytes per packed weight row for the given x4x2 quant type and row length k:
 * per super-block quantized payload plus its scale/exponent block.
 * Returns 0 for unsupported types. */
static inline size_t get_x4x2_row_stride(int weight_type, int k) {
    const int nb = (k + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;  /* super-blocks per row */
    if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) {
        /* nibble-packed payload + fp16 scale block */
        return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE);
    }
    if (weight_type == HTP_TYPE_Q8_0) {
        /* byte payload + fp16 scale block */
        return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE);
    }
    if (weight_type == HTP_TYPE_MXFP4) {
        /* nibble-packed payload + E8M0 exponent block */
        return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE);
    }
    return 0;
}
/* Checked size_t multiply: returns true (and leaves *out untouched) on
 * overflow, otherwise stores a*b in *out and returns false. */
static inline bool hmx_mul_overflow(size_t a, size_t b, size_t *out) {
    const bool overflows = (a != 0) && (b > SIZE_MAX / a);
    if (!overflows) {
        *out = a * b;
    }
    return overflows;
}
/* Checked size_t add: returns true (and leaves *out untouched) on overflow,
 * otherwise stores a+b in *out and returns false. */
static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) {
    const bool overflows = a > SIZE_MAX - b;
    if (!overflows) {
        *out = a + b;
    }
    return overflows;
}
/*
 * Choose VTCM chunk sizes for a tiled matmul of m rows by n columns.
 *
 * VTCM cost model (all sizes in bytes):
 *   total = overhead + nc*per_n_cost + mc*per_m_cost + mc*nc*per_mn_cost
 * The search walks candidate column chunks nc (multiples of
 * HMX_FP16_TILE_N_COLS, largest first), derives the largest row chunk mc that
 * still fits, and scores each candidate by the number of m/n blocks it implies
 * (mblocks*m_block_cost + nblocks*n_block_cost), preferring lower cost and,
 * on ties, larger mc*nc. All arithmetic is overflow-checked.
 *
 * On success returns 0 and writes the chosen chunk sizes plus the exact VTCM
 * total for that choice; returns -1 when nothing fits or inputs are invalid.
 */
static int hmx_compute_chunks(size_t vtcm_total,
                              size_t overhead,
                              size_t per_n_cost,
                              size_t per_m_cost,
                              size_t per_mn_cost,
                              int m,
                              int n,
                              size_t m_block_cost,
                              size_t n_block_cost,
                              size_t * m_chunk_out,
                              size_t * n_chunk_out,
                              size_t * total_out) {
    if (m <= 0 || n <= 0) return -1;
    if (vtcm_total <= overhead) return -1;
    if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1;
    const size_t usable = vtcm_total - overhead;
    size_t best_cost = SIZE_MAX;
    size_t best_mn = 0;
    size_t best_m = 0, best_n = 0;
    const size_t n_max = hex_align_down((size_t) n, HMX_FP16_TILE_N_COLS);
    for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) {
        size_t n_fixed = 0, ncmn = 0, mc_denom = 0;
        /* nc-dependent fixed cost; skip candidates whose arithmetic overflows */
        if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue;
        if (n_fixed >= usable) goto next_nc;
        if (hmx_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc;
        if (hmx_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc;
        {
            /* Largest mc fitting the remaining budget, rounded down to whole
             * tile rows and clamped to the problem size. */
            size_t remain = usable - n_fixed;
            size_t mc = remain / mc_denom;
            mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS);
            mc = hex_smin(mc, (size_t) m);
            if (mc == 0) {
                goto next_nc;
            }
            size_t mblocks = ((size_t) m + mc - 1) / mc;
            size_t nblocks = ((size_t) n + nc - 1) / nc;
            size_t cost = mblocks * m_block_cost + nblocks * n_block_cost;
            size_t mn = mc * nc;
            /* Keep the cheapest candidate; break ties by larger chunk area. */
            if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
                best_cost = cost;
                best_mn = mn;
                best_m = mc;
                best_n = nc;
            }
        }
next_nc:
        /* Explicit exit before nc would wrap below one tile of columns. */
        if (nc == HMX_FP16_TILE_N_COLS) break;
    }
    if (best_m == 0 || best_n == 0) return -1;
    /* Recompute the winning total with full overflow checking. */
    size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0;
    if (hmx_mul_overflow(best_n, per_n_cost, &t0)) return -1;
    if (hmx_mul_overflow(best_m, per_m_cost, &t1)) return -1;
    if (hmx_mul_overflow(best_m, best_n, &mn)) return -1;
    if (hmx_mul_overflow(mn, per_mn_cost, &t2)) return -1;
    if (hmx_add_overflow(t0, t1, &total)) return -1;
    if (hmx_add_overflow(total, t2, &total)) return -1;
    if (hmx_add_overflow(total, overhead, &total)) return -1;
    *m_chunk_out = best_m;
    *n_chunk_out = best_n;
    *total_out = total;
    return 0;
}
/* Defined elsewhere: copies an activation chunk into VTCM as __fp16 using the
 * worker pool (called below with float source rows of length k_block and
 * source row stride k_stride). */
void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride);
/*
 * Repack an n_cols x k fp16 weight matrix (row-major in VTCM) into the HMX
 * tile layout: tile (ct, kt) holds rows [ct*TILE_N_ROWS, ...) and columns
 * [kt*TILE_N_COLS, ...). Two source rows are processed per iteration: their
 * fp16 values travel as packed word lanes and are scattered into the tile
 * with per-lane byte offsets (v_scat_base + lane-local adjustments).
 */
static void interleave_fp16_weight_chunk_to_tiles(__fp16 *restrict vtcm_dst,
                                                  const __fp16 *restrict vtcm_src,
                                                  int n_cols, int k) {
    assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
    assert(k % HMX_FP16_TILE_N_COLS == 0);
    const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
    const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets);
    const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
    /* Predicate covering the first 64 bytes of the vector. */
    const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
    for (int r = 0; r < n_cols; r += 2) {
        /* ct: destination column-tile; local_r: row within that tile. */
        int ct = r / HMX_FP16_TILE_N_ROWS; int local_r = r % HMX_FP16_TILE_N_ROWS; const bool next_row_valid = (r + 1) < n_cols;
        HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
        HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
        for (int c = 0; c < k; c += HMX_FP16_TILE_N_COLS) {
            int kt = c / HMX_FP16_TILE_N_COLS;
            int tile_idx = ct * n_k_tiles + kt;
            __fp16 *tile_base = vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS;
            /* Unaligned loads of one tile-width of each source row; the odd
             * trailing row is zero-filled. */
            HVX_Vector v0 = hvx_vmemu(vtcm_src + r * k + c);
            HVX_Vector v1 = next_row_valid ? hvx_vmemu(vtcm_src + (r + 1) * k + c) : Q6_V_vzero();
            Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off0, v0);
            Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off1, v1);
        }
    }
}
/*
 * Dequantize one 32-value 4-bit group to fp16: select the low or high nibble
 * of each packed byte, translate through the nibble LUT (vlut16), and scale
 * by the single fp16 scale. Returns 64 fp16 values in one vector (lower half
 * carries the group; layout follows the vshuff/vlut16 lane ordering).
 */
static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(
    const uint8_t *packed_32, bool upper_nibbles,
    const __fp16 *scale, const HVX_Vector vlut_cvt) {
    HVX_Vector vq = hvx_vmemu(packed_32);
    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
    HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
    /* Shift by 4 only when the group lives in the upper nibbles. */
    HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
    v_quants = Q6_V_vand_VV(v_quants, mask_h4);
    v_quants = Q6_Vb_vshuff_Vb(v_quants);
    HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
    HVX_Vector v_hf = Q6_V_lo_W(vp);
    return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
}
/*
 * Dequantize four consecutive 32-value 4-bit groups (128 packed bytes) at
 * once. scales_4 points at four consecutive fp16 scales, one per group; each
 * output vector carries two groups, scaled group-wise via the 64-byte mux.
 *
 * NOTE(review): although out[] is declared with 4 slots, only out[0] and
 * out[1] are written — the caller (the Q4 fast path below) scatters each
 * vector across a two-tile region, so two vectors cover all four groups.
 */
static inline void dequantize_x4x2_q4_0_x4groups_hvx(
    const uint8_t *packed_128, bool upper_nibbles,
    const __fp16 *scales_4, const HVX_Vector vlut_cvt,
    HVX_Vector out[4]) {
    HVX_Vector vq = hvx_vmemu(packed_128);
    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
    HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
    v_quants = Q6_V_vand_VV(v_quants, mask_h4);
    v_quants = Q6_Vb_vshuff_Vb(v_quants);
    HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
    HVX_Vector v_lo = Q6_V_lo_W(vp); HVX_Vector v_hi = Q6_V_hi_W(vp);
    /* First 64 bytes get scale[0]/[2], remaining bytes scale[1]/[3]. */
    HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
    HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[0]), hvx_vec_splat_f16(scales_4[1]));
    HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[2]), hvx_vec_splat_f16(scales_4[3]));
    v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
    v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
    out[0] = v_lo; out[1] = v_hi; }
/*
 * Dequantize one 32-value Q8_0 group: sign-extend the int8 quants to
 * halfwords, convert to fp16, and multiply by the group's fp16 scale.
 */
static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(
    const int8_t *quants_32, const __fp16 *scale) {
    HVX_Vector vq = hvx_vmemu(quants_32);
    HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
    HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq));
    HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
    return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
}
/* Eight MXFP4 block scales converted to fp16 (see mxfp4_convert_scales). */
typedef struct {
    __fp16 v[8] __attribute__((aligned(16)));
} mxfp4_scales_t;
/*
 * Convert 8 E8M0 exponent bytes into fp16 power-of-two scales by building the
 * fp16 bit pattern directly: rebase the 8-bit exponent by -112 (presumably
 * the fp32->fp16 bias difference 127-15 — confirm), clamp to the fp16
 * exponent range [0, 30], and shift it into the exponent field (bit 10).
 */
static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) {
    mxfp4_scales_t s;
    HVX_Vector v = hvx_vmemu(e8m0_8);
    HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v));
    vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112));
    vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero());
    vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30));
    vh = Q6_Vh_vasl_VhR(vh, 10);
    hvx_vec_store_u(s.v, 16, vh);
    return s;
}
/* Broadcast scale idx (0..7) across a full fp16 vector. */
static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) {
    return hvx_vec_splat_f16(scales.v[idx]);
}
/*
 * Dequantize one 32-value MXFP4 group: pick the low or high nibble of each
 * packed byte, map E2M1 code points through the LUT, then apply the
 * sub-block's fp16 scale (scales.v[sub_blk]).
 */
static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32,
                                                         bool upper_nibbles,
                                                         int sub_blk,
                                                         const HVX_Vector vlut_cvt,
                                                         mxfp4_scales_t scales) {
    HVX_Vector vq = hvx_vmemu(packed_32);
    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
    HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
    v_quants = Q6_V_vand_VV(v_quants, mask_h4);
    HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk);
    v_quants = Q6_Vb_vshuff_Vb(v_quants);
    HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
    HVX_Vector v_hf = Q6_V_lo_W(vp);
    return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc));
}
/*
 * Dequantize four consecutive 32-value MXFP4 groups (128 packed bytes) in one
 * pass. Each lut16 output vector carries two groups; per-group scales
 * (scales.v[sub_blk_base .. +3]) are applied via the 64-byte mux. The four
 * result vectors each present one group in the lower 64 bytes — out[1]/out[3]
 * rotate the second group down so every vector is consumed the same way.
 */
static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128,
                                                      bool upper_nibbles,
                                                      int sub_blk_base,
                                                      const HVX_Vector vlut_cvt,
                                                      mxfp4_scales_t scales,
                                                      HVX_Vector out[4]) {
    HVX_Vector vq = hvx_vmemu(packed_128);
    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
    HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
    v_quants = Q6_V_vand_VV(v_quants, mask_h4);
    v_quants = Q6_Vb_vshuff_Vb(v_quants);
    HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
    HVX_Vector v_lo = Q6_V_lo_W(vp);
    HVX_Vector v_hi = Q6_V_hi_W(vp);
    HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
    HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0),
                                      mxfp4_extract_splat(scales, sub_blk_base + 1));
    HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2),
                                      mxfp4_extract_splat(scales, sub_blk_base + 3));
    v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
    v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
    out[0] = v_lo;
    out[1] = Q6_V_vror_VR(v_lo, 64);
    out[2] = v_hi;
    out[3] = Q6_V_vror_VR(v_hi, 64);
}
/*
 * Dequantize HMX weight tiles [start_tile, end_tile) from packed x4x2 layout
 * (Q4_0 / IQ4_NL / Q8_0 / MXFP4) in VTCM into fp16 HMX tiles in vtcm_dst.
 *
 * Tiles are numbered column-tile-major: tile t covers column-tile
 * ct = t / n_k_tiles and k-tile kt = t % n_k_tiles. Each packed source row
 * holds the quantized payload followed by per-block scales/exponents starting
 * at byte offset qrow_size.
 *
 * The 4-bit types have fast paths that produce 4 consecutive k-tiles of one
 * column-tile per pass. Results land via HVX scatter; the trailing volatile
 * loads drain the scatters before the caller reads the tiles.
 *
 * BUG FIX: the MXFP4 fast path advanced t by 4 without advancing kt (the Q4
 * fast path advances both), leaving kt stale for all subsequent tiles and
 * producing wrong block offsets; it now does `kt += 4` as well.
 */
static void dequantize_x4x2_weight_to_fp16_tiles_task(
    __fp16 *restrict vtcm_dst,
    const uint8_t *restrict vtcm_src,
    int n_cols, int k_block,
    size_t row_stride, int weight_type,
    int start_tile, int end_tile) {
    const int n_k_tiles = (unsigned) k_block / HMX_FP16_TILE_N_COLS;
    const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
    /* Quantized payload bytes per row (nibble-packed for the 4-bit types). */
    const int qrow_size = is_q4 ? ((unsigned) k_block / 2) : k_block;
    const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
                                (weight_type == HTP_TYPE_MXFP4)  ? hvx_vmem(mxfp4_to_fp16_lut) :
                                                                   hvx_vmem(q4_0_to_fp16_lut);
    const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets);
    const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
    const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
    unsigned ct = (unsigned) start_tile / n_k_tiles; /* current column tile */
    unsigned kt = (unsigned) start_tile % n_k_tiles; /* current k tile */
    for (unsigned t = start_tile; t < end_tile;) {
        if (kt >= n_k_tiles) {
            kt = 0;
            ct++;
        }
        /* Q4_0 / IQ4_NL fast path: 4 aligned k-tiles of one column tile.
         * Each x4groups call yields two vectors, each scattered across a
         * two-tile region (2 * HMX_FP16_TILE_SIZE - 1 mask), covering all
         * four destination tiles with two scatters per row. */
        if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
            unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
            unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32;
            bool upper = (sub_blk_base >= 4);
            unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2);
            unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk_base * (int) sizeof(__fp16);
            __fp16 *tile_bases[4];
            for (unsigned g = 0; g < 4; g++) {
                tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS;
            }
            HVX_Vector v_off = v_scat_base;
            unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
            for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
                HVX_Vector v0[2];
                const uint8_t *r0 = vtcm_src + row_offset;
                row_offset += row_stride;
                dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *) (r0 + scale_off), vlut_cvt, v0);
                Q6_vscatter_RMVwV((size_t) tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]);
                Q6_vscatter_RMVwV((size_t) tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
                r0 = vtcm_src + row_offset;
                row_offset += row_stride;
                dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *) (r0 + scale_off), vlut_cvt, v0);
                Q6_vscatter_RMVwV((size_t) tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]);
                Q6_vscatter_RMVwV((size_t) tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
            }
            /* Volatile reads force the scatters to complete. */
            for (int g = 0; g < 4; g++) {
                (void) *(volatile HVX_Vector *) (tile_bases[g]);
            }
            t += 4;
            kt += 4;
            continue;
        }
        /* MXFP4 fast path: 4 aligned k-tiles of one column tile, one
         * predicated scatter per tile per row. */
        if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
            int blk_idx = (kt * 32) / QK_MXFP4x4x2;
            int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32;
            bool upper = (sub_blk_base >= 4);
            int packed_off = blk_idx * (QK_MXFP4x4x2 / 2);
            int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE;
            __fp16 * tile_bases[4];
            for (int g = 0; g < 4; g++) {
                tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS;
            }
            HVX_Vector v_off = v_scat_base;
            for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
                int row0 = ct * HMX_FP16_TILE_N_COLS + r;
                int row1 = row0 + 1;
                const uint8_t * r0 = vtcm_src + row0 * row_stride;
                const uint8_t * r1 = vtcm_src + row1 * row_stride;
                mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
                HVX_Vector v0[4], v1[4];
                dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0);
                if (row1 < n_cols) {
                    mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
                    dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1);
                } else {
                    /* Zero-pad the odd trailing row. */
                    v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
                }
                for (int g = 0; g < 4; g++) {
                    Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]);
                }
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
                for (int g = 0; g < 4; g++) {
                    Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]);
                }
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
            }
            for (int g = 0; g < 4; g++) {
                (void) *(volatile HVX_Vector *) (tile_bases[g]);
            }
            t += 4;
            kt += 4; /* FIX: keep the k-tile cursor in sync (was missing). */
            continue;
        }
        /* Generic single-tile path. */
        __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;
        if (is_q4) {
            unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
            unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
            bool upper = (sub_blk >= 4);
            unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
            unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int) sizeof(__fp16);
            HVX_Vector v_off = v_scat_base;
            unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
            unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
            for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
                const uint8_t *r0 = vtcm_src + row_offset;
                row_offset += row_stride;
                const uint8_t *r1 = vtcm_src + row_offset;
                row_offset += row_stride;
                HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(
                    r0 + byte_off, upper, (const __fp16 *) (r0 + scale_off), vlut_cvt);
                HVX_Vector v1 = (row1 < n_cols)
                    ? dequantize_x4x2_q4_0_group_hvx(
                          r1 + byte_off, upper, (const __fp16 *) (r1 + scale_off), vlut_cvt)
                    : Q6_V_vzero();
                Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
                Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
            }
            (void) *(volatile HVX_Vector *) (tile_base);
        } else if (weight_type == HTP_TYPE_MXFP4) {
            int blk_idx = (kt * 32) / QK_MXFP4x4x2;
            int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32;
            bool upper = (sub_blk >= 4);
            int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
            int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE;
            HVX_Vector v_off = v_scat_base;
            for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
                int row0 = ct * HMX_FP16_TILE_N_COLS + r;
                int row1 = row0 + 1;
                const uint8_t * r0 = vtcm_src + row0 * row_stride;
                const uint8_t * r1 = vtcm_src + row1 * row_stride;
                mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
                HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8);
                HVX_Vector v1;
                if (row1 < n_cols) {
                    mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
                    v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8);
                } else {
                    v1 = Q6_V_vzero();
                }
                Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
                Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
            }
            (void) *(volatile HVX_Vector *) (tile_base);
        } else {
            /* Q8_0: byte payload, fp16 scales. */
            int blk_idx = (kt * 32) / QK_Q8_0x4x2;
            int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32;
            int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32;
            int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int) sizeof(__fp16);
            HVX_Vector v_off = v_scat_base;
            for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
                int row0 = ct * HMX_FP16_TILE_N_COLS + r;
                int row1 = row0 + 1;
                const uint8_t *r0 = vtcm_src + row0 * row_stride;
                const uint8_t *r1 = vtcm_src + row1 * row_stride;
                HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx(
                    (const int8_t *) (r0 + byte_off), (const __fp16 *) (r0 + scale_off));
                HVX_Vector v1 = (row1 < n_cols)
                    ? dequantize_x4x2_q8_0_group_hvx(
                          (const int8_t *) (r1 + byte_off), (const __fp16 *) (r1 + scale_off))
                    : Q6_V_vzero();
                Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
                Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
            }
            (void) *(volatile HVX_Vector *) (tile_base);
        }
        ++t;
        ++kt;
    }
    /* Final drain of the last tile written by this task. */
    if (start_tile < end_tile) {
        (void) *(volatile HVX_Vector *) (vtcm_dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS);
    }
}
/* Shared state for the threaded x4x2 dequantization
 * (see dequantize_x4x2_worker_loop / dequantize_x4x2_weight_chunk_to_fp16_tiles). */
typedef struct {
    __fp16 *dst;          /* fp16 tile output base (VTCM) */
    const uint8_t *src;   /* packed x4x2 weight source (VTCM) */
    int n_cols;           /* weight columns in this chunk */
    int k_block;          /* K extent of the chunk */
    size_t row_stride;    /* bytes between consecutive packed weight rows */
    int weight_type;      /* HTP_TYPE_Q4_0 / IQ4_NL / Q8_0 / MXFP4 */
    int n_tot_tiles;      /* total tiles = col-tiles * k-tiles */
    int n_tiles_per_task; /* tiles per task */
    int n_tasks;          /* ceil(n_tot_tiles / n_tiles_per_task) */
} x4x2_dequantize_state_t;
/* Worker-pool entry: worker i handles tasks i, i+n, i+2n, ...; each task is a
 * contiguous range of tiles, the last one clamped to n_tot_tiles. */
static void dequantize_x4x2_worker_loop(unsigned int n, unsigned int i, void *data) {
    x4x2_dequantize_state_t *st = (x4x2_dequantize_state_t *) data;
    unsigned int task_id = i;
    while (task_id < (unsigned int) st->n_tasks) {
        const int first = task_id * st->n_tiles_per_task;
        const int last = hex_smin(first + st->n_tiles_per_task, st->n_tot_tiles);
        dequantize_x4x2_weight_to_fp16_tiles_task(st->dst, st->src, st->n_cols, st->k_block,
                                                  st->row_stride, st->weight_type, first, last);
        task_id += n;
    }
}
/* Dequantize one packed x4x2 weight chunk into fp16 HMX tiles, spreading the
 * tiles evenly across the worker pool. */
static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
    struct htp_context *ctx, __fp16 *vtcm_dst,
    const void *vtcm_src, int n_cols, int k_block,
    size_t row_stride, int weight_type) {
    assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
    assert(k_block % HMX_FP16_TILE_N_COLS == 0);
    const size_t col_tiles = n_cols / HMX_FP16_TILE_N_COLS;
    const size_t k_tiles = k_block / HMX_FP16_TILE_N_COLS;
    const size_t total_tiles = col_tiles * k_tiles;
    const size_t per_task = hmx_ceil_div(total_tiles, ctx->n_threads);
    x4x2_dequantize_state_t state = {
        .dst = vtcm_dst,
        .src = (const uint8_t *) vtcm_src,
        .n_cols = n_cols,
        .k_block = k_block,
        .row_stride = row_stride,
        .weight_type = weight_type,
        .n_tot_tiles = (int) total_tiles,
        .n_tiles_per_task = (int) per_task,
        .n_tasks = (int) ((total_tiles + per_task - 1) / per_task),
    };
    worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads);
}
/*
 * HMX tile matmul kernel: output[r][c] tile = sum over k of
 * activation tile (r,k) x weight tile (c,k). Activation and weight are laid
 * out tile-contiguous per row/column; scales is the column-scale table bound
 * via Q6_bias_mxmem2_A. Accumulation happens in the HMX accumulator, cleared
 * per output tile and drained with Q6_mxmem_AR_after_hf.
 */
static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales,
                                int n_row_tiles, int n_col_tiles, int n_dot_tiles) {
    __builtin_assume(n_row_tiles > 0);
    __builtin_assume(n_col_tiles > 0);
    __builtin_assume(n_dot_tiles > 0);
    Q6_bias_mxmem2_A((void *)scales);
    for (int r = 0; r < n_row_tiles; ++r) {
        for (size_t c = 0; c < n_col_tiles; ++c) {
            Q6_mxclracc_hf();
            const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
            const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
            for (int k = 0; k < n_dot_tiles; ++k) {
                /* 2047 = tile byte-range mask for the mxmem addressing mode. */
                Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047);
                Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047);
                row_tiles += HMX_FP16_TILE_N_ELMS;
                col_tiles += HMX_FP16_TILE_N_ELMS;
            }
            __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS;
            Q6_mxmem_AR_after_hf(out_tile, 0);
        }
    }
}
/* Argument bundle for one core_dot_chunk_fp16 invocation on a worker thread. */
typedef struct {
    __fp16 * output;            /* output tiles */
    const __fp16 * activation;  /* activation tiles */
    const __fp16 * weight;      /* weight tiles */
    const __fp16 * scales;      /* column scale table */
    uint32_t n_row_tiles;
    uint32_t n_col_tiles;
    uint32_t n_dot_tiles;       /* K-dimension tile count */
} hmx_matmul_job_t;
/* Worker-thread trampoline: unpack the job and run the HMX matmul kernel. */
static void hmx_matmul_worker_fn(void * data) {
    hmx_matmul_job_t * job = (hmx_matmul_job_t *) data;
    FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles);
    core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles);
}
/* Populate an HMX matmul job descriptor (plain field copy, no validation). */
static inline void hmx_matmul_job_init(hmx_matmul_job_t * job,
                                       __fp16 * output,
                                       const __fp16 * activation,
                                       const __fp16 * weight,
                                       const __fp16 * scales,
                                       int n_row_tiles,
                                       int n_col_tiles,
                                       int n_dot_tiles) {
    *job = (hmx_matmul_job_t) {
        .output = output,
        .activation = activation,
        .weight = weight,
        .scales = scales,
        .n_row_tiles = (uint32_t) n_row_tiles,
        .n_col_tiles = (uint32_t) n_col_tiles,
        .n_dot_tiles = (uint32_t) n_dot_tiles,
    };
}
/*
 * Convert an HMX fp16 tiled output chunk to fp32 rows at dst with row stride
 * n floats. Two output rows are produced per tile vector: multiplying the
 * fp16 vector by 1.0 with Q6_Wqf32_vmpy yields a qf32 pair whose low/high
 * halves are the even/odd row. The odd trailing row is skipped when n_rows
 * is odd. NOTE(review): dst rows are written with full-vector volatile
 * stores — assumes dst rows are vector-aligned and n is a multiple of a
 * vector of floats; confirm against callers.
 */
static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) {
    assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
    /* Elements between vertically adjacent tiles. */
    const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS;
    const HVX_Vector one = hvx_vec_splat_f16(1.0);
    for (size_t r = 0; r < n_rows; r += 2) {
        const size_t r0 = r / HMX_FP16_TILE_N_ROWS;          /* tile row index */
        /* r1: vector index of this row pair inside the tile. */
        const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; const __fp16 *row_base = vtcm_src + r0 * tile_row_stride;
        float *output_row_base = dst + r * n;
#pragma unroll(4)
        for (size_t c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) {
            const size_t c0 = c / HMX_FP16_TILE_N_COLS;
            const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS;
            HVX_Vector v = ((const HVX_Vector *) tile)[r1];
            /* fp16 -> qf32 widening via multiply-by-one. */
            HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one);
            volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row_base + c + 0);
            volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (output_row_base + c + n);
            *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
            if (r + 1 < n_rows) {
                *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
            }
        }
    }
}
/* Shared state for the threaded fp16 -> fp32 output transfer. */
typedef struct {
    const __fp16 *vtcm_src;  /* tiled fp16 output in VTCM */
    float *dst;              /* fp32 destination rows */
    int n_tasks;             /* ceil(n_tot_chunks / n_chunks_per_task) */
    int n_tot_chunks;        /* total output rows */
    int n_chunks_per_task;   /* rows per task (one tile-row group) */
    int n_cols;              /* columns per row in the tiled source */
    int n; } output_transfer_task_state_t;  /* dst row stride in floats */
/* Worker: convert its strided share of row-chunks from fp16 tiles to fp32. */
static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
    output_transfer_task_state_t *state = (output_transfer_task_state_t *) data;
    /* Strided pickup: worker i takes tasks i, i+n, i+2n, ... */
    for (unsigned int task = i; task < (unsigned int) state->n_tasks; task += n) {
        const int first_row = task * state->n_chunks_per_task;
        const size_t row_count = hex_smin(state->n_tot_chunks - first_row, state->n_chunks_per_task);
        float *out_rows = state->dst + first_row * state->n;
        const __fp16 *src_rows = state->vtcm_src + first_row * state->n_cols;
        transfer_output_chunk_fp16_to_fp32(out_rows, src_rows, row_count, state->n_cols, state->n);
    }
}
static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src,
int n_rows, int n_cols, int n) {
assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
size_t n_tot_chunks = n_rows;
size_t n_chunks_per_task = HMX_FP16_TILE_N_ROWS;
output_transfer_task_state_t state;
state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task;
state.n_tot_chunks = n_tot_chunks;
state.n_chunks_per_task = n_chunks_per_task;
state.dst = dst;
state.vtcm_src = vtcm_src;
state.n_cols = n_cols;
state.n = n;
worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads);
}
/* Dim-2 broadcast ratio: activation batches per weight batch (1 if ne02 is
 * not positive). */
static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) {
    if (params->ne02 > 0) {
        return params->ne12 / params->ne02;
    }
    return 1;
}
/* Dim-3 broadcast ratio: activation batches per weight batch (1 if ne03 is
 * not positive). */
static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) {
    if (params->ne03 > 0) {
        return params->ne13 / params->ne03;
    }
    return 1;
}
/* Weight base for destination batch (dst_b2, dst_b3): the batch index is
 * divided by the broadcast ratio so several activation batches share one
 * weight batch (GQA-style reuse). Strides are in bytes. */
static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params,
                                                        int dst_b2, int dst_b3) {
    const size_t off2 = (size_t) (dst_b2 / hmx_matmul_batch_r2(params)) * params->src0_nb2;
    const size_t off3 = (size_t) (dst_b3 / hmx_matmul_batch_r3(params)) * params->src0_nb3;
    const uint8_t *base = (const uint8_t *) params->permuted_weight;
    return (const __fp16 *) (base + off2 + off3);
}
/* Activation base for destination batch (dst_b2, dst_b3); byte strides. */
static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params,
                                                           int dst_b2, int dst_b3) {
    const uint8_t *base = (const uint8_t *) params->activation;
    const size_t off = (size_t) dst_b2 * params->src1_nb2 + (size_t) dst_b3 * params->src1_nb3;
    return (const float *) (base + off);
}
/* Destination base for batch (dst_b2, dst_b3); byte strides. */
static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params,
                                              int dst_b2, int dst_b3) {
    uint8_t *base = (uint8_t *) params->dst;
    const size_t off = (size_t) dst_b2 * params->dst_nb2 + (size_t) dst_b3 * params->dst_nb3;
    return (float *) (base + off);
}
static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx,
const hmx_matmul_w16a32_batched_params_t *params) {
int ret = 0;
for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) {
for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) {
ret = hmx_mat_mul_permuted_w16a32(ctx,
hmx_matmul_dst_batch_ptr(params, b2, b3),
hmx_matmul_activation_batch_ptr(params, b2, b3),
hmx_matmul_weight_batch_ptr(params, b2, b3),
params->m, params->k, params->n,
params->act_stride, params->weight_stride);
}
}
return ret;
}
/*
 * Batched fp16-weight / fp32-activation HMX matmul with GQA weight reuse:
 * when several dim-2 activation batches share one weight batch (group > 1),
 * each weight chunk is loaded into VTCM once and multiplied against all
 * activations of the group. Falls back to the per-batch legacy loop when
 * there is no reuse or the grouped working set does not fit VTCM.
 * Returns 0 on success, -1 on invalid arguments or VTCM overflow.
 */
int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) {
    /* Argument validation: non-null buffers, non-zero dims, strides that
     * cover a row, divisible broadcast ratios, 32-multiple k/n, and
     * vector-aligned buffers. */
    if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; }
    if (!params->m || !params->k || !params->n) { return -1; }
    if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; }
    if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; }
    if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; }
    if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; }
    if (!hex_is_aligned(params->dst, VLEN) ||
        !hex_is_aligned(params->activation, VLEN) ||
        !hex_is_aligned(params->permuted_weight, VLEN)) {
        return -1;
    }
    const int group_size = hmx_matmul_batch_r2(params);
    if (group_size <= 1) {
        /* No weight sharing across dim-2 batches: nothing to gain here. */
        FARF(MEDIUM, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size);
        return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
    }
    const size_t vtcm_budget = ctx->vtcm_size;
    const size_t vec_dot_size = params->k * sizeof(__fp16);  /* one fp16 row */
    /* Activation rows that are not densely packed must be DMA'd row by row
     * through an fp32 staging buffer before conversion. */
    const bool use_dma_activation = (params->act_stride > params->k);
    const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0;
    size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
    /* Cost model: per-n = weight tiles + double-buffered scratch (3x rows);
     * per-m = one fp16 activation copy per group member (+ fp32 staging);
     * per-mn = the fp16 output tile area; 256 bytes overhead for scales. */
    if (hmx_compute_chunks(vtcm_budget, 256,
                           3 * vec_dot_size,
                           group_size * vec_dot_size + f32_scratch_per_m,
                           sizeof(__fp16), params->m, params->n,
                           (size_t) params->n,
                           (size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
        FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__);
        return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
    }
    /* VTCM layout, all areas tile-aligned. */
    const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE);
    const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE);
    const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE);
    const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE);
    const size_t f32_scratch_size = use_dma_activation
        ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0;
    uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base;
    __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size);
    __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size);
    __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size);
    /* Two scratch buffers double-buffer the raw weight DMA. */
    void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size);
    void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size);
    __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
    float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL;
    /* Layout sanity check against the actual allocation cursor. */
    if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) {
        FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__);
        return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
    }
    /* Unity column scales (0x3c00 == fp16 1.0). */
    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00));
    FARF(MEDIUM, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu",
         __func__, params->m, params->k, params->n, group_size, params->ne13,
         m_chunk_n_rows, n_chunk_n_cols,
         (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);
    TIMER_DEFINE(activation_load);
    TIMER_DEFINE(weight_load);
    TIMER_DEFINE(hmx_core);
    TIMER_DEFINE(output_store);
    TIMER_DEFINE(total);
    TIMER_START(total);
    const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16);
    const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16);
    HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
    for (int b3 = 0; b3 < params->ne13; ++b3) {
        for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) {
            /* One weight batch serves group_size activation batches. */
            const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3);
            for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) {
                const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows);
                const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS);
                /* Stage all group members' activation rows into VTCM as fp16. */
                TIMER_START(activation_load);
                for (int g = 0; g < group_size; ++g) {
                    const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride;
                    __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
                    if (use_dma_activation) {
                        /* Strided rows: DMA into the fp32 staging area first. */
                        const size_t row_bytes = (size_t) params->k * sizeof(float);
                        const size_t stride_bytes = (size_t) params->act_stride * sizeof(float);
                        dma_queue_push(ctx->dma[0],
                                       dma_make_ptr(vtcm_f32_act, activation_chunk),
                                       row_bytes, stride_bytes, row_bytes, n_rows);
                        dma_queue_pop(ctx->dma[0]);
                        transfer_activation_chunk_threaded(ctx, vtcm_act_g,
                                                           vtcm_f32_act, (int) n_rows,
                                                           params->k, params->k);
                    } else {
                        transfer_activation_chunk_threaded(ctx, vtcm_act_g,
                                                           activation_chunk, (int) n_rows,
                                                           params->k, params->act_stride);
                    }
                }
                TIMER_STOP(activation_load);
                /* Prime the weight DMA pipeline with the first column chunk. */
                void *buf_curr = vtcm_scratch0;
                void *buf_next = vtcm_scratch1;
                {
                    const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols);
                    dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group),
                                   fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first);
                }
                for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) {
                    const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols);
                    const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS);
                    TIMER_START(weight_load);
                    {
                        /* Wait for the current chunk, kick off the next one,
                         * then tile-interleave the current chunk. */
                        dma_queue_pop(ctx->dma[0]);
                        const size_t nc_next = nc + n_chunk_n_cols;
                        if (nc_next < (size_t) params->n) {
                            const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols);
                            const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride;
                            dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk),
                                           fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next);
                        }
                        interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k);
                        swap_ptr(&buf_curr, &buf_next);
                    }
                    TIMER_STOP(weight_load);
                    /* Reuse the staged weight tiles for every group member. */
                    for (int g = 0; g < group_size; ++g) {
                        TIMER_START(hmx_core);
                        {
                            const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
                            core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles,
                                                params->k / 32);
                        }
                        TIMER_STOP(hmx_core);
                        TIMER_START(output_store);
                        {
                            float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc;
                            transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride);
                        }
                        TIMER_STOP(output_store);
                    }
                }
            }
        }
    }
    HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
    TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
    FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total),
         params->m, params->k, params->n, group_size);
    FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
         TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
#endif
    return 0;
}
/*
 * hmx_mat_mul_permuted_w16a32
 *
 * dst[m x n] (f32) = activation[m x k] (f32) * permuted_weight (f16),
 * where the weight is stored one k-length vector per output column
 * ("permuted" layout), executed on the HMX engine with all operands
 * staged through VTCM.
 *
 * Strategy: tile over m (activation rows) and n (output columns).
 * Activation chunks are converted f32->f16 into VTCM tile layout by the
 * worker pool; weight column-chunks are double-buffered through
 * vtcm_scratch0/vtcm_scratch1 so the next chunk's DMA overlaps with
 * interleaving and consuming the current chunk.
 *
 * Returns 0 on success, -1 on invalid arguments or insufficient VTCM.
 * NOTE(review): output rows are stored with stride n (dst assumed to be a
 * contiguous m x n array) - confirm against callers.
 */
int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
const __fp16 *restrict permuted_weight, int m, int k, int n,
int act_stride, int weight_stride) {
/* Reject null pointers and degenerate sizes. */
if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; }
if (act_stride < k || weight_stride < k) { return -1; }
/* Whole-tile dimensions only: k and n must be multiples of 32. */
if (k % 32 != 0 || n % 32 != 0) { return -1; }
/* All buffers must be HVX-vector aligned for the vectorized transfers. */
if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) {
return -1;
}
const size_t vtcm_budget = ctx->vtcm_size;
/* Bytes of one k-length f16 vector (one weight column / one activation row). */
const size_t vec_dot_size = k * sizeof(__fp16);
/* Strided activation rows in DDR are first DMA-gathered into a contiguous
 * f32 VTCM scratch before the f32->f16 tile conversion. */
const bool use_dma_activation = (act_stride > k);
const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0;
size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
/* Choose m/n chunk sizes that fit the VTCM budget. Per-n cost covers the
 * weight tile area plus both DMA scratch buffers (3 * vec_dot_size). */
if (hmx_compute_chunks(vtcm_budget,
256,
3 * vec_dot_size, vec_dot_size + f32_scratch_per_m, sizeof(__fp16), m, n,
(size_t) n,
(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
return -1;
}
/* VTCM arena layout, all regions rounded up to whole HMX tiles. */
const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE);
const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE);
const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE);
const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE);
const size_t f32_scratch_size = use_dma_activation
? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0;
uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base;
__fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size);
__fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size);
__fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size);
void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size);
void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size);
__fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL;
if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) {
FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__,
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
return -1;
}
/* Unity column scales (0x3c00 == 1.0 in fp16). */
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00));
FARF(MEDIUM, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu",
__func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols,
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
TIMER_DEFINE(activation_load);
TIMER_DEFINE(weight_load);
TIMER_DEFINE(hmx_core);
TIMER_DEFINE(output_store);
TIMER_DEFINE(total);
TIMER_START(total);
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
TIMER_START(activation_load);
{
const float *activation_chunk = activation + mr * act_stride;
if (use_dma_activation) {
/* Gather strided rows into contiguous f32 scratch, then convert. */
const size_t row_bytes = (size_t) k * sizeof(float);
const size_t stride_bytes = (size_t) act_stride * sizeof(float);
dma_queue_push(ctx->dma[0],
dma_make_ptr(vtcm_f32_act, activation_chunk),
row_bytes, stride_bytes, row_bytes, n_rows);
dma_queue_pop(ctx->dma[0]);
transfer_activation_chunk_threaded(ctx, vtcm_activation,
vtcm_f32_act, n_rows, k, k);
} else {
/* Rows already contiguous enough: convert straight from DDR. */
transfer_activation_chunk_threaded(ctx, vtcm_activation,
activation_chunk, n_rows, k, act_stride);
}
}
TIMER_STOP(activation_load);
const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16);
const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16);
/* Double-buffered weight streaming: prefetch the first chunk, then each
 * iteration pops the ready buffer and pushes the DMA for the next one. */
void *buf_curr = vtcm_scratch0;
void *buf_next = vtcm_scratch1;
{
const size_t n_cols_first = hex_smin(n, n_chunk_n_cols);
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight),
fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first);
}
for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
TIMER_START(weight_load);
{
/* Wait for the current chunk, kick off the next, then interleave. */
dma_queue_pop(ctx->dma[0]);
const size_t nc_next = nc + n_chunk_n_cols;
if (nc_next < n) {
const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols);
const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride;
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk),
fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next);
}
interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *)buf_curr, n_cols, k);
swap_ptr(&buf_curr, &buf_next);
}
TIMER_STOP(weight_load);
TIMER_START(hmx_core);
{
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32);
}
TIMER_STOP(hmx_core);
TIMER_START(output_store);
{
/* Store f16 tile results as f32 into dst (row stride n). */
float *output = dst + (mr * n + nc);
transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n);
}
TIMER_STOP(output_store);
}
}
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n);
FARF(HIGH, "  activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
{
/* bytes / us * 1e-3 == GB/s. */
size_t weight_size = (size_t)k * n * sizeof(__fp16);
float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load);
FARF(HIGH, "  weight load bandwidth: %.2f GB/s", bandwidth);
}
#endif
return 0;
}
/* Forward declaration: out-stationary (accumulate-over-k) variant defined
 * further down; returns FALLBACK_TO_STANDARD when its blocking does not pay off. */
int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m,
int k, int n, int w_type);
/* Sentinel return value (not an error): caller should run the standard path. */
#define FALLBACK_TO_STANDARD 1
/*
 * hmx_mat_mul_permuted_qk_0_d16a32
 *
 * dst[m x n] (f32) = activation[m x k] (f32) * quantized weight, where the
 * weight is stored in a packed "x4x2" per-column row layout selected by
 * weight_type (row byte stride from get_x4x2_row_stride). Weights are
 * DMA-fetched, dequantized to f16 HMX tiles in VTCM, and multiplied on HMX.
 *
 * Three strategies, chosen by shape:
 *  - out-stationary k-blocked path (large m, k > n > 1024), tried first;
 *  - SEQUENTIAL path: double-buffered weight DMA, synchronous HMX calls;
 *  - PIPELINE path (m >= 128, k <= n): 3-deep software pipeline overlapping
 *    weight DMA (chunk i+2), dequant (i+2), HMX compute (i+1, via hmx_queue
 *    worker) and output store (i).
 *
 * Returns 0 on success, -1 on invalid arguments or insufficient VTCM.
 */
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
const uint8_t *restrict permuted_weight, int m, int k, int n,
int weight_type) {
if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; }
/* Whole-tile dimensions only. */
if (k % 32 != 0 || n % 32 != 0) { return -1; }
if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) {
return -1;
}
/* Try the out-stationary variant for large, k-dominant shapes; it may
 * decline with FALLBACK_TO_STANDARD, in which case we continue below. */
if (m >= 128 && k > n && n > 1024) {
int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type);
if (rc != FALLBACK_TO_STANDARD) {
return rc; }
FARF(MEDIUM, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n);
}
/* Byte stride of one packed weight row; 0 means unsupported weight_type. */
size_t row_stride = get_x4x2_row_stride(weight_type, k);
if (row_stride == 0) {
return -1;
}
FARF(MEDIUM, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type);
const size_t vtcm_budget = ctx->vtcm_size;
const size_t vec_dot_size = k * sizeof(__fp16);
/* Pipeline needs extra buffers (two dequantized weight chunks + a second
 * output buffer); only worth it for tall activations with k <= n. */
const bool use_pipeline = (m >= 128) && (k <= n);
size_t per_n_cost, per_mn_cost;
if (use_pipeline) {
per_n_cost = row_stride + 2 * vec_dot_size; per_mn_cost = 2 * sizeof(__fp16); } else {
per_n_cost = vec_dot_size + 2 * row_stride; per_mn_cost = sizeof(__fp16); }
size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
if (hmx_compute_chunks(vtcm_budget, 256, per_n_cost, vec_dot_size, per_mn_cost, m, n,
(size_t) n * 3,
(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d pipe=%d budget=%zu)",
__func__, m, k, n, use_pipeline, vtcm_budget);
return -1;
}
/* VTCM layout. Pipeline: weight area holds the raw quantized chunk and
 * scratch0/1 hold two dequantized f16 chunks; sequential is the reverse
 * (scratch0/1 double-buffer the raw chunks, weight area holds f16 tiles). */
const size_t weight_area_size = hex_align_up(
n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE);
const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE);
const size_t output_area_size = hex_align_up(
m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE);
size_t scratch0_size, scratch1_size, scratch2_size;
if (use_pipeline) {
scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); scratch1_size = scratch0_size; scratch2_size = output_area_size; } else {
scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); scratch1_size = scratch0_size; scratch2_size = 0; }
uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base;
__fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size);
__fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size);
__fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size);
void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size);
void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size);
void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL;
__fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) {
FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__,
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
return -1;
}
/* Unity column scales (0x3c00 == 1.0 in fp16). */
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00));
FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu",
__func__, m, k, n, weight_type, use_pipeline,
m_chunk_n_rows, n_chunk_n_cols,
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
TIMER_DEFINE(activation_load);
TIMER_DEFINE(weight_load);
TIMER_DEFINE(hmx_core);
TIMER_DEFINE(output_store);
TIMER_DEFINE(total);
TIMER_START(total);
FARF(MEDIUM, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu",
use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols,
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
if (!use_pipeline) {
/* ---- SEQUENTIAL path: synchronous HMX, double-buffered weight DMA. ---- */
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
TIMER_START(activation_load);
{
/* NOTE(review): assumes activation rows are contiguous (stride == k). */
const float *activation_chunk = activation + mr * k;
transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k);
}
TIMER_STOP(activation_load);
void *buf_curr = vtcm_scratch0;
void *buf_next = vtcm_scratch1;
{
/* Prefetch the first quantized weight chunk. */
const size_t n_cols_first = hex_smin(n, n_chunk_n_cols);
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first);
}
for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
TIMER_START(weight_load);
{
/* Wait for current chunk, queue the next, dequantize current. */
dma_queue_pop(ctx->dma[0]);
const size_t nc_next = nc + n_chunk_n_cols;
if (nc_next < n) {
const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols);
const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride;
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next);
}
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type);
swap_ptr(&buf_curr, &buf_next);
}
TIMER_STOP(weight_load);
TIMER_START(hmx_core);
{
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32);
}
TIMER_STOP(hmx_core);
TIMER_START(output_store);
{
float *output = dst + (mr * n + nc);
transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n);
}
TIMER_STOP(output_store);
}
}
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
} else {
/* ---- PIPELINE path: HMX compute runs on the hmx_queue worker while this
 * thread overlaps weight DMA, dequantization and output stores.
 * NOTE(review): no HAP_compute_res_hmx_lock here - presumably the
 * hmx_queue worker owns the HMX lock; confirm. ---- */
int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
hmx_matmul_job_t job_slots[2];
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
/* Raw quantized chunks land in vtcm_weight; dequantized f16 tiles
 * ping-pong between scratch0/scratch1; outputs between output/scratch2. */
void *vtcm_qweight = vtcm_weight;
void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 };
void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 };
const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols);
{
/* Prologue: fetch chunk 0 while activations convert. */
const uint8_t *qweight_chunk_A0 = permuted_weight;
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0);
}
{
/* NOTE(review): assumes activation rows are contiguous (stride == k). */
const float *activation_chunk = activation + mr * k;
transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k);
}
{
/* Prologue: dequant chunk 0, start fetching chunk 1, launch job 0,
 * then dequant chunk 1 so the steady-state loop is two chunks ahead. */
dma_queue_pop(ctx->dma[0]);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type);
const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
if (1 < n_chunk_cnt) {
const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride;
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1);
}
hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation,
(__fp16 *) vtcm_weight_bufs[0], vtcm_scales,
hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0]));
if (1 < n_chunk_cnt) {
dma_queue_pop(ctx->dma[0]);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type);
}
}
/* Steady state for chunk i: DMA i+2, wait job i, launch job i+1,
 * store output i, dequant i+2. */
for (int i = 0; i < n_chunk_cnt; ++i) {
const size_t nc = i * n_chunk_n_cols;
const size_t nc_p1 = nc + 1 * n_chunk_n_cols;
const size_t nc_p2 = nc + 2 * n_chunk_n_cols;
const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols);
const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols);
if (i + 2 < n_chunk_cnt) {
const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride;
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2);
}
/* Wait for the HMX job that produced output chunk i. */
hmx_queue_pop(ctx->hmx_queue);
if (i + 1 < n_chunk_cnt) {
hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2],
(__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2],
vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2]));
}
float *output_chunk = dst + (mr * n + nc);
transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n);
if (i + 2 < n_chunk_cnt) {
dma_queue_pop(ctx->dma[0]);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type);
}
}
}
hmx_queue_suspend(ctx->hmx_queue);
}
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d pipeline=%d", __func__, TIMER_US(total), m, k, n, use_pipeline);
if (!use_pipeline) {
FARF(HIGH, "  activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
size_t weight_size = (size_t)n * row_stride;
float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load);
FARF(HIGH, "  weight load bandwidth: %.2f GB/s", bandwidth);
}
#endif
return 0;
}
/*
 * core_mma_chunk_fp16: tile-blocked multiply-accumulate on the HMX engine.
 *
 * For every (row tile i, col tile j) pair: clears the HMX accumulator,
 * optionally reloads the previous partial result from c (via an
 * identity-weight multiply with eye_tile when !zero_init, so C += A*B
 * accumulates across k blocks), streams n_dot_tiles activation/weight tile
 * pairs through the MAC array, and writes the accumulator back to c.
 *
 * a, b, c are VTCM tile-major arrays; col_scales is the HMX per-column
 * bias/scale table. NOTE(review): the (unsigned int) pointer casts assume
 * 32-bit Hexagon addresses.
 */
void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile,
int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) {
__builtin_assume(n_row_tiles > 0);
__builtin_assume(n_col_tiles > 0);
__builtin_assume(n_dot_tiles > 0);
/* Bind the column scale table for subsequent HMX operations. */
Q6_bias_mxmem2_A((void *)col_scales);
const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS;
for (size_t i = 0; i < n_row_tiles; ++i) {
const __fp16 *row_base = a + i * dot_tile_stride;
__fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS;
for (size_t j = 0; j < n_col_tiles; ++j) {
Q6_mxclracc_hf();
const __fp16 *col_tiles = b + j * dot_tile_stride;
const __fp16 *row_tiles = row_base;
__fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS;
if (!zero_init) {
/* Preload existing partial sum: accum_tile * I lands in the accumulator. */
Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047);
Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047);
}
for (int k = 0; k < n_dot_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047);
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047);
row_tiles += HMX_FP16_TILE_N_ELMS;
col_tiles += HMX_FP16_TILE_N_ELMS;
}
/* Flush the accumulator tile back to VTCM. */
Q6_mxmem_AR_after_hf(accum_tile, 0);
}
}
}
/*
 * Convert an f32 activation chunk to f16 and scatter it into HMX tile-major
 * layout. Two source rows are processed per iteration so each output HVX
 * vector holds an interleaved row pair (hvx_vec_f32_to_f16_shuff); the last
 * odd row, if any, is paired with zeros.
 *
 * r0/c0 index the tile grid; r1 is the row within its tile, and tile[r1/2]
 * places the interleaved pair at the right vector slot inside the tile.
 * Assumes src rows are VLEN-aligned and k_block is a multiple of 32.
 */
static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows,
int k_block, int k_stride) {
for (int r = 0; r < n_rows; r += 2) {
int r0 = r / HMX_FP16_TILE_N_ROWS; int r1 = r % HMX_FP16_TILE_N_ROWS;
const bool next_row_valid = (r + 1) < n_rows;
const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride);
const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride);
for (int c = 0; c < k_block; c += 32) {
HVX_Vector v0 = *pv_in0++;
/* Zero-pad when n_rows is odd. */
HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero();
HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1);
int c0 = c / HMX_FP16_TILE_N_COLS; int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0;
HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS);
tile[r1 / 2] = v_out;
}
}
}
/* Shared state for the threaded f32->f16 activation transfer. */
typedef struct {
__fp16 *dst;             /* VTCM destination, tile-major f16 */
const float *src;        /* DDR/VTCM source, row-major f32 */
int n_tasks;             /* number of work items to divide among threads */
int n_tot_chunks;        /* total rows to convert */
int n_chunks_per_task;   /* rows handled per work item */
int k_block;             /* columns to convert per row */
int k_stride;            /* source row stride in floats */
} activation_transfer_task_state_t;
/* Worker-pool entry point: thread i of n converts its strided share of the
 * row chunks described by *data from f32 into tiled f16. */
static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
    const activation_transfer_task_state_t *state = (const activation_transfer_task_state_t *) data;
    const unsigned int total_tasks = (unsigned int) state->n_tasks;

    for (unsigned int t = i; t < total_tasks; t += n) {
        const int first_row = (int) t * state->n_chunks_per_task;
        /* The final task may own fewer rows than a full task's share. */
        const size_t rows_here = hex_smin(state->n_tot_chunks - first_row, state->n_chunks_per_task);
        transfer_activation_chunk_fp32_to_fp16(state->dst + first_row * state->k_block,
                                               state->src + first_row * state->k_stride,
                                               rows_here, state->k_block, state->k_stride);
    }
}
/* Fan the f32->f16 activation conversion out across the worker pool,
 * splitting n_rows into fixed-size row batches. */
void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) {
    assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0);
    assert(VLEN == 32 * sizeof(float));

    enum { ROWS_PER_TASK = 32 };
    activation_transfer_task_state_t state = {
        .dst = dst,
        .src = src,
        .n_tasks = (n_rows + ROWS_PER_TASK - 1) / ROWS_PER_TASK,
        .n_tot_chunks = n_rows,
        .n_chunks_per_task = ROWS_PER_TASK,
        .k_block = k_block,
        .k_stride = k_stride,
    };
    worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads);
}
int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w,
int m, int k, int n, int weight_type) {
const size_t row_stride = get_x4x2_row_stride(weight_type, k);
if (row_stride == 0) {
return -1;
}
const size_t vtcm_budget = ctx->vtcm_size;
const size_t K_BLOCK_SIZE = 1024;
const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE;
if (k_iters_check <= 1) {
FARF(MEDIUM, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k);
return FALLBACK_TO_STANDARD;
}
const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE);
const size_t per_m = K_BLOCK_SIZE * sizeof(float) + K_BLOCK_SIZE * sizeof(__fp16); const size_t per_n = sub_row_stride_alloc + K_BLOCK_SIZE * sizeof(__fp16); const size_t per_mn = sizeof(__fp16); const size_t align_margin = 4 * HMX_FP16_TILE_SIZE;
const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin;
size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used;
const size_t m_block_cost = (size_t) n * 3;
const size_t n_block_cost = (size_t) m * 2;
if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE,
&N_BLOCK_SIZE, &vtcm_used) != 0) {
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
return -1;
}
const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE);
const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE);
const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256;
if (total_vtcm > vtcm_budget) {
FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm,
vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE);
return -1;
}
uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base;
__fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size);
__fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size);
__fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size);
uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz);
uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz);
__fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE);
__fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget);
FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", __func__, m, k, n, weight_type,
M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);
{
HVX_Vector v;
v = Q6_V_vzero();
v = Q6_Vw_vinsert_VwR(v, 0x3c000000);
v = Q6_V_vror_VR(v, VLEN - 4);
v = Q6_Vw_vinsert_VwR(v, 0x00003c00);
for (int i = 0; i < 16; ++i) {
((HVX_Vector *) vtcm_eye_tile)[i] = v;
v = Q6_V_vror_VR(v, VLEN - 8);
}
}
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00));
TIMER_DEFINE(fetch);
TIMER_DEFINE(act_load);
TIMER_DEFINE(wt_dequant);
TIMER_DEFINE(core);
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) {
size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE);
for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) {
size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE);
const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS);
const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS);
for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) {
const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE);
TIMER_START(fetch);
{
const float *activation_block = x + mr * k + kk;
dma_queue_push(ctx->dma[0],
dma_make_ptr(vtcm_scratch1, activation_block),
k_blk_sz * sizeof(float),
k * sizeof(float),
k_blk_sz * sizeof(float),
m_blk_sz);
}
const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
{
qweight_fetch_task_state_t s;
const int blk_start = kk / QK_Q4_0x4x2;
const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;
const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2);
const int scale_blk_size =
(weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE;
s.dst = vtcm_scratch0;
s.src = w + nc * row_stride;
s.n_rows = n_blk_sz;
s.src_stride = row_stride;
s.dst_stride = sub_row_stride;
s.quant_off =
(weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2));
s.quant_width =
(weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2));
s.scale_off = full_qrow + blk_start * scale_blk_size;
s.scale_width = nb_sub * scale_blk_size;
dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off),
s.dst_stride, s.src_stride, s.quant_width, s.n_rows);
dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off),
s.dst_stride, s.src_stride, s.scale_width, s.n_rows);
}
TIMER_STOP(fetch);
TIMER_START(act_load);
{
dma_queue_pop(ctx->dma[0]); transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz);
}
TIMER_STOP(act_load);
TIMER_START(wt_dequant);
{
dma_queue_pop(ctx->dma[0]);
dma_queue_pop(ctx->dma[0]);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0,
n_blk_sz, k_blk_sz, sub_row_stride, weight_type);
}
TIMER_STOP(wt_dequant);
TIMER_START(core);
{
core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles,
n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0);
}
TIMER_STOP(core);
}
{
float *output_block = out + (mr * n + nc);
transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n);
}
}
}
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us",
TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core));
#endif
return 0;
}