#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <assert.h>
#include <HAP_compute_res.h>
#include <HAP_farf.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "hex-dma.h"
#include "hmx-profile.h"
#include "hmx-queue.h"
#include "hmx-utils.h"
#include "htp-ctx.h"
#include "htp-ops.h"
#include "hvx-dump.h"
#include "hvx-reduce.h"
#include "hvx-utils.h"
#include "vtcm-utils.h"
#include "worker-pool.h"
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads) {
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096);
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096);
const size_t k_dma_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096);
const size_t v_dma_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096);
const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096);
const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096);
const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096);
const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096);
const size_t col_vec_size = hex_align_up(g_br * sizeof(__fp16), 256);
const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256);
const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128);
const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096);
const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128);
return q_tile_size * 1 + o_tile_size * 2 + k_dma_size * 2 + v_dma_size * 2 + k_tile_size * 1 +
       v_tile_size * 1 + s_tile_size * 2 + d_tile_size * 1 + col_vec_size * 4 +
       row_vec_size * 2 * n_threads + m_buf_size * 1 + slopes_size + 256 * 2;
}
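// Vectorized 2^x for 64 fp16 lanes: split x into an integer part k (via the -0.5 bias and
// float-to-int conversion) and the remainder f = x - k, evaluate a degree-6 polynomial for
// 2^f with fp16 coefficients given as raw bit patterns, then scale by 2^k with direct
// exponent-field arithmetic. Lanes whose result exponent would underflow flush to zero.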
static inline HVX_Vector hvx_exp2_hf(HVX_Vector x_v) {
const HVX_Vector zero_v = Q6_V_vzero();
const HVX_Vector half_hf_v = Q6_Vh_vsplat_R(0x3800);
HVX_Vector x_minus_half = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, half_hf_v));
HVX_Vector k_v = Q6_Vh_equals_Vhf(x_minus_half);
HVX_Vector f_v = Q6_Vhf_equals_Vh(k_v);
HVX_Vector x_qf16 = Q6_Vqf16_vsub_VhfVhf(x_v, f_v);
HVX_Vector y = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), x_qf16); y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x157d)); y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x20ed)); y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x2b1b)); y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x33b0)); y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x398c)); y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x3c00));
y = Q6_Vhf_equals_Vqf16(y);
HVX_Vector y_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(y, 1), 11);
y_exp = Q6_Vh_vadd_VhVh(k_v, y_exp);
HVX_VectorPred q_underflow = Q6_Q_vcmp_gt_VhVh(zero_v, y_exp);
y = Q6_Vh_vaslacc_VhVhR(y, k_v, 10);
return Q6_V_vmux_QVV(q_underflow, zero_v, y);
}
#define FA_MIN_KV_BLOCKS 3
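// Search for the best (Br, Bc) tile pair that fits the VTCM budget. Br is scanned
// downward in br_unit steps; for each Br, a linear memory model (the per_gbr/per_bc
// terms) gives a first-cut Bc, which is then tightened against the exact
// hmx_fa_compute_vtcm_usage() figure. Among feasible pairs, the one with the lowest
// estimated block-iteration cost wins, with ties broken by larger Br*Bc.
// Returns 0 on success, -1 if nothing fits.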
static int hmx_fa_find_chunk_size(size_t * Br_out,
size_t * Bc_out,
size_t gqa_factor,
size_t DK,
size_t DV,
size_t qo_len,
size_t kv_len,
size_t vtcm_budget,
size_t n_threads) {
const size_t T = HMX_FP16_TILE_N_ROWS;
const size_t br_unit = hmx_ceil_div(T, gqa_factor);
const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2;
const size_t fp16 = sizeof(__fp16);
const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16;
const size_t per_gbr2 = fp16;
const size_t per_bc = 3 * (DK + DV) * fp16 + 2 * n_threads * fp16;
const size_t per_gbr_bc = 2 * fp16;
const size_t overhead = 256 * 2 + 13 * 4096;
if (vtcm_budget <= overhead) {
return -1;
}
const size_t usable = vtcm_budget - overhead;
const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit;
const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2);
const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) :
(kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit);
const size_t c_q_fixed = 1400;
const size_t c_iter_fixed = 200;
size_t best_cost = SIZE_MAX, best_mn = 0;
size_t best_Br = 0, best_Bc = 0;
for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) {
const size_t g_br = hex_align_up(gqa_factor * Br, T);
const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2;
if (gbr_cost >= usable) {
if (Br == br_unit) {
break;
}
continue;
}
const size_t remain = usable - gbr_cost;
const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16;
size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit);
if (Bc < bc_unit) {
if (Br == br_unit) {
break;
}
continue;
}
while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads) > vtcm_budget) {
Bc -= bc_unit;
}
if (Bc < bc_unit) {
if (Br == br_unit) {
break;
}
continue;
}
const size_t q_blocks = (qo_len + Br - 1) / Br;
const size_t kv_blocks = (kv_len + Bc - 1) / Bc;
const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed);
const size_t mn = Br * Bc;
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
best_cost = cost;
best_mn = mn;
best_Br = Br;
best_Bc = Bc;
}
if (Br == br_unit) {
break;
}
}
if (best_Br == 0) {
return -1;
}
*Br_out = best_Br;
*Bc_out = best_Bc;
return 0;
}
static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = {
0 * 136, 0 * 136 + 6,
1 * 136, 1 * 136 + 6,
2 * 136, 2 * 136 + 6,
3 * 136, 3 * 136 + 6,
4 * 136, 4 * 136 + 6,
5 * 136, 5 * 136 + 6,
6 * 136, 6 * 136 + 6,
7 * 136, 7 * 136 + 6,
8 * 136, 8 * 136 + 6,
9 * 136, 9 * 136 + 6,
10 * 136, 10 * 136 + 6,
11 * 136, 11 * 136 + 6,
12 * 136, 12 * 136 + 6,
13 * 136, 13 * 136 + 6,
14 * 136, 14 * 136 + 6,
15 * 136, 15 * 136 + 6,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
};
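// Per-op state shared by all worker phases: problem geometry, softmax parameters, and
// the VTCM scratch buffers carved out at the start of hmx_flash_attn_ext().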
struct hmx_fa_context {
const struct htp_ops_context * octx;
bool use_pipeline;
uint32_t n_threads;
float scale;
float max_bias;
float logit_softcap;
uint32_t n_head_log2;
float m0, m1;
uint32_t DK, DV;
uint32_t n_kv;
uint32_t n_kv_heads;
uint32_t n_heads;
uint32_t G;
uint32_t n_kv_blocks;
uint32_t neq1;
bool is_q_fp32;
bool is_dst_fp32;
uint32_t Br;
uint32_t Bc;
uint32_t g_br;
__fp16 * vtcm_q_tiles;
__fp16 * vtcm_o_tiles[2];
__fp16 * vtcm_k_fp16[2];
__fp16 * vtcm_v_fp16[2];
__fp16 * vtcm_k_tiles;
__fp16 * vtcm_v_tiles;
__fp16 * vtcm_s_tiles;
__fp16 * vtcm_p_tiles;
__fp16 * vtcm_d_tiles;
HVX_Vector * vtcm_m_vec;
HVX_Vector * vtcm_l_vec;
HVX_Vector * vtcm_s_rowmax;
HVX_Vector * vtcm_p_rowsum;
HVX_Vector * vtcm_row_bufs;
uint8_t * vtcm_hmx_scales_id;
uint8_t * vtcm_hmx_scales_qk;
__fp16 * vtcm_mask_buf;
__fp16 * vtcm_slopes;
size_t row_buf_stride;
size_t mask_buf_row_stride;
bool mask_broadcast;
};
typedef struct {
struct hmx_fa_context * factx;
int kv_rows;
size_t src_stride;
size_t buf_idx;
} fa_k_int_args_t;
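// Worker-pool callback: each of the n threads interleaves its slice of the DMA-staged
// K rows (row-major fp16) into the HMX tile format used as the weight operand of QK^T.
// Slices are rounded up to even row counts because tiles pack rows in pairs.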
static void fa_k_interleave_thread(unsigned int n, unsigned int i, void * data) {
fa_k_int_args_t * args = (fa_k_int_args_t *) data;
struct hmx_fa_context * factx = args->factx;
const int total_rows = args->kv_rows;
const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2);
const int start = i * rows_per_t;
const int end = hex_smin(start + rows_per_t, total_rows);
if (start >= total_rows) {
return;
}
hmx_interleave_rows_to_tiles(factx->vtcm_k_tiles, factx->vtcm_k_fp16[args->buf_idx], total_rows, (int) factx->DK,
(int) args->src_stride, start, end);
}
static void fa_phase_k_interleave(struct hmx_fa_context * factx, int kv_rows, size_t src_stride, size_t buf_idx) {
worker_pool_context_t wp = factx->octx->ctx->worker_pool;
fa_k_int_args_t args = { factx, kv_rows, src_stride, buf_idx };
if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) {
worker_pool_run_func(wp, fa_k_interleave_thread, &args, factx->n_threads);
} else {
fa_k_interleave_thread(1, 0, &args);
}
}
typedef struct {
struct hmx_fa_context * factx;
int kv_rows;
size_t src_stride;
size_t buf_idx;
size_t n_col_tiles;
} fa_v_int_args_t;
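// Worker-pool callback: interleaves the DMA-staged V rows into column-tile order so V
// can feed the HMX as the weight operand of the P*V product.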
static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) {
fa_v_int_args_t * args = (fa_v_int_args_t *) data;
struct hmx_fa_context * factx = args->factx;
const int total_rows = args->kv_rows;
const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2);
const int start = i * rows_per_t;
const int end = hex_smin(start + rows_per_t, total_rows);
if (start >= total_rows) {
return;
}
hmx_interleave_cols_to_tiles(factx->vtcm_v_tiles, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV,
(int) args->src_stride, (int) args->n_col_tiles, start, end);
}
static void fa_phase_v_interleave(struct hmx_fa_context * factx,
int kv_rows,
size_t src_stride,
size_t buf_idx,
size_t n_col_tiles) {
worker_pool_context_t wp = factx->octx->ctx->worker_pool;
fa_v_int_args_t args = { factx, kv_rows, src_stride, buf_idx, n_col_tiles };
if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) {
worker_pool_run_func(wp, fa_v_interleave_thread, &args, factx->n_threads);
} else {
fa_v_interleave_thread(1, 0, &args);
}
}
typedef struct {
struct hmx_fa_context * factx;
const struct htp_tensor * q;
uint32_t q_start;
uint32_t kv_head;
uint32_t ib3;
size_t n_rows_g;
} fa_q_load_args_t;
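// Worker-pool callback: gathers Br query rows (times the G heads of the GQA group) from
// DDR into the VTCM Q tiles, converting fp32 to fp16 on the fly when needed. Rows are
// processed in pairs so each HVX store fills one interleaved tile row pair.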
static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) {
fa_q_load_args_t * args = (fa_q_load_args_t *) data;
struct hmx_fa_context * factx = args->factx;
const size_t n_rows_g = args->n_rows_g;
const size_t G = factx->G;
const size_t DK = factx->DK;
const size_t rows_per_t = hex_align_up(hmx_ceil_div(n_rows_g, n), 2);
const size_t start = (size_t) i * rows_per_t;
const size_t end = hex_smin(start + rows_per_t, n_rows_g);
if (start >= n_rows_g) {
return;
}
const struct htp_tensor * q = args->q;
const uint32_t q_start = args->q_start;
const uint32_t kv_head = args->kv_head;
const uint32_t ib3 = args->ib3;
for (size_t r = start; r < end; r += 2) {
const bool next_row_valid = (r + 1) < n_rows_g;
const size_t q_idx0 = (r + 0) / G;
const size_t h_idx0 = (r + 0) % G;
const size_t q_idx1 = (r + 1) / G;
const size_t h_idx1 = (r + 1) % G;
const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] +
(kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3];
const uint8_t * q_ptr1 = next_row_valid ? ((const uint8_t *) q->data + (q_start + q_idx1) * q->nb[1] +
(kv_head * G + h_idx1) * q->nb[2] + ib3 * q->nb[3]) :
NULL;
size_t r0 = r / HMX_FP16_TILE_N_ROWS;
size_t r1 = r % HMX_FP16_TILE_N_ROWS;
__fp16 * out_base = factx->vtcm_q_tiles + r0 * HMX_FP16_TILE_N_ROWS * DK;
if (factx->is_q_fp32) {
const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0;
const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL;
for (uint32_t d = 0; d < DK / 32; ++d) {
HVX_Vector v0 = pv_in0[d];
HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero();
HVX_Vector v_hf = hvx_vec_f32_to_f16_shuff(v0, v1);
HVX_Vector * out_tile = (HVX_Vector *) (out_base + d * HMX_FP16_TILE_N_ELMS);
out_tile[r1 / 2] = v_hf;
}
} else {
const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0;
const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL;
for (uint32_t d = 0; d < DK / 64; ++d) {
HVX_Vector v0 = pv_in0[d];
HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero();
HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2);
__fp16 * out_dual_tile = out_base + d * HMX_FP16_TILE_N_ELMS * 2;
HVX_Vector * pv_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2;
HVX_Vector * pv_out1 = pv_out0 + 16;
*pv_out0 = Q6_V_lo_W(vp);
*pv_out1 = Q6_V_hi_W(vp);
}
}
}
}
static void fa_phase_q_load(struct hmx_fa_context * factx,
const struct htp_tensor * q,
uint32_t q_start,
uint32_t kv_head,
uint32_t ib3,
size_t n_rows_g) {
worker_pool_context_t wp = factx->octx->ctx->worker_pool;
fa_q_load_args_t args = { factx, q, q_start, kv_head, ib3, n_rows_g };
if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) {
worker_pool_run_func(wp, fa_q_load_thread, &args, factx->n_threads);
} else {
fa_q_load_thread(1, 0, &args);
}
}
typedef struct {
struct hmx_fa_context * factx;
const struct htp_tensor * dst;
const __fp16 * o_tile_src;
uint32_t q_start;
uint32_t kv_head;
uint32_t ib3;
size_t n_rows_g;
} fa_o_store_args_t;
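// Worker-pool callback: de-interleaves the final O tiles back into row-major layout and
// writes them to the destination tensor, widening to fp32 when the dst type requires it.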
static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) {
fa_o_store_args_t * args = (fa_o_store_args_t *) data;
struct hmx_fa_context * factx = args->factx;
const size_t n_rows_g = args->n_rows_g;
const size_t G = factx->G;
const size_t DV = factx->DV;
const size_t rows_per_t = hmx_ceil_div(n_rows_g, n);
const size_t start = (size_t) i * rows_per_t;
const size_t end = hex_smin(start + rows_per_t, n_rows_g);
if (start >= n_rows_g) {
return;
}
const struct htp_tensor * dst = args->dst;
const __fp16 * o_tile_src = args->o_tile_src;
const uint32_t q_start = args->q_start;
const uint32_t kv_head = args->kv_head;
const uint32_t ib3 = args->ib3;
for (size_t r = start; r < end; ++r) {
const size_t q_idx = r / G;
const size_t h_idx = r % G;
uint8_t * dst_row = (uint8_t *) dst->data + (kv_head * G + h_idx) * dst->nb[1] +
(q_start + q_idx) * dst->nb[2] + ib3 * dst->nb[3];
size_t r0 = r / HMX_FP16_TILE_N_ROWS;
size_t r1 = r % HMX_FP16_TILE_N_ROWS;
const __fp16 * tile_row_base = o_tile_src + r0 * HMX_FP16_TILE_N_ROWS * DV;
if (factx->is_dst_fp32) {
float * out = (float *) dst_row;
for (uint32_t d = 0; d < DV / 32; ++d) {
const HVX_Vector * in_tile = (const HVX_Vector *) (tile_row_base + d * HMX_FP16_TILE_N_ELMS);
HVX_VectorPair vp = hvx_vec_f16_to_f32_shuff(in_tile[r1 / 2]);
if (r1 % 2 == 0) {
*(HVX_UVector *) (out + d * 32) = Q6_V_lo_W(vp);
} else {
*(HVX_UVector *) (out + d * 32) = Q6_V_hi_W(vp);
}
}
} else {
__fp16 * out = (__fp16 *) dst_row;
for (uint32_t d = 0; d < DV / 64; ++d) {
const __fp16 * in_dual_tile = tile_row_base + d * HMX_FP16_TILE_N_ELMS * 2;
const HVX_Vector * pv_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2;
const HVX_Vector * pv_in1 = pv_in0 + 16;
HVX_VectorPair vp = Q6_W_vdeal_VVR(*pv_in1, *pv_in0, -2);
if (r1 % 2 == 0) {
*(HVX_UVector *) (out + d * 64) = Q6_V_lo_W(vp);
} else {
*(HVX_UVector *) (out + d * 64) = Q6_V_hi_W(vp);
}
}
}
}
}
static void fa_phase_o_store(struct hmx_fa_context * factx,
const struct htp_tensor * dst,
const __fp16 * o_tile_src,
uint32_t q_start,
uint32_t kv_head,
uint32_t ib3,
size_t n_rows_g) {
worker_pool_context_t wp = factx->octx->ctx->worker_pool;
fa_o_store_args_t args = { factx, dst, o_tile_src, q_start, kv_head, ib3, n_rows_g };
if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) {
worker_pool_run_func(wp, fa_o_store_thread, &args, factx->n_threads);
} else {
fa_o_store_thread(1, 0, &args);
}
}
typedef struct {
struct hmx_fa_context * factx;
size_t kv_rows;
size_t n_rows_g;
size_t n_col_tiles;
size_t n_tiles_per_bc;
size_t n_row_tiles;
size_t n_row_tiles_g_br;
uint32_t Bc;
uint32_t G;
uint32_t kv_head;
uint32_t kv_start;
uint32_t q_start;
uint32_t ib3;
bool has_alibi;
__fp16 * slopes;
const struct htp_tensor * mask;
const __fp16 * mask_vtcm;
size_t mask_vtcm_row_stride;
} fa_softmax_args_t;
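// Worker-pool callback: the online-softmax step for one KV block. For each pair of S
// rows it de-interleaves the scores, applies the optional logit softcap, mask, and ALiBi
// slopes, tracks the running row max, exponentiates (base 2 or e depending on
// HMX_FA_USE_EXP2_HF), re-interleaves the resulting P tiles, and accumulates per-row
// sums. The per-row max and sum land in vtcm_s_rowmax / vtcm_p_rowsum for the (m, l) update.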
static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
fa_softmax_args_t * args = (fa_softmax_args_t *) data;
struct hmx_fa_context * factx = args->factx;
const size_t n_rows_g = args->n_rows_g;
const size_t kv_rows = args->kv_rows;
const size_t Bc = args->Bc;
const size_t G = args->G;
const size_t n_tiles_per_bc = args->n_tiles_per_bc;
const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64);
const size_t vecs_per_t = hmx_ceil_div(n_row_vec_cnt, n);
const size_t vec_start = i * vecs_per_t;
const size_t vec_end = hex_smin(vec_start + vecs_per_t, n_row_vec_cnt);
if (vec_start >= n_row_vec_cnt) {
return;
}
const size_t row_buf_stride = factx->row_buf_stride;
HVX_Vector * my_row_buf0 = factx->vtcm_row_bufs + i * 2 * row_buf_stride;
HVX_Vector * my_row_buf1 = my_row_buf0 + row_buf_stride;
const HVX_Vector v_neg_inf = Q6_Vh_vsplat_R(0xfbff);
for (size_t r_vec_idx = vec_start; r_vec_idx < vec_end; ++r_vec_idx) {
HVX_Vector rowmax_acc_v = v_neg_inf;
HVX_Vector rowsum_acc_v = Q6_V_vzero();
HVX_Vector m_prev_v = factx->vtcm_m_vec[r_vec_idx];
for (int r_vec_off = 0; r_vec_off < 64; r_vec_off += 2) {
int r = r_vec_idx * 64 + r_vec_off;
if (r >= (int) hex_align_up(n_rows_g, 2)) {
break;
}
int r0 = r / HMX_FP16_TILE_N_ROWS;
int r1 = r % HMX_FP16_TILE_N_ROWS;
const __fp16 * s_ld_base = factx->vtcm_s_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc;
__fp16 * p_st_base = factx->vtcm_p_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc;
HVX_Vector * pv_row_buf0 = my_row_buf0;
HVX_Vector * pv_row_buf1 = my_row_buf1;
for (size_t c = 0; c < kv_rows; c += 64) {
const __fp16 * in_dual_tile = s_ld_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2;
const HVX_Vector * pv_s_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2;
const HVX_Vector * pv_s_in1 = pv_s_in0 + 16;
HVX_VectorPair vp_s_dual_row = Q6_W_vdeal_VVR(*pv_s_in1, *pv_s_in0, -2);
*pv_row_buf0++ = Q6_V_lo_W(vp_s_dual_row);
*pv_row_buf1++ = Q6_V_hi_W(vp_s_dual_row);
}
if (factx->logit_softcap != 0.0f) {
float cap = factx->logit_softcap;
#ifdef HMX_FA_USE_EXP2_HF
cap *= 1.44269504f;
#endif
const HVX_Vector v_cap = hvx_vec_splat_f32(cap);
for (size_t c = 0; c < kv_rows; c += 64) {
size_t ci = c / 64;
HVX_VectorPair r0_f32 = hvx_vec_f16_to_f32(my_row_buf0[ci]);
HVX_Vector t0_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r0_f32));
HVX_Vector t0_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r0_f32));
t0_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_lo, v_cap));
t0_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_hi, v_cap));
my_row_buf0[ci] = hvx_vec_f32_to_f16(t0_lo, t0_hi);
HVX_VectorPair r1_f32 = hvx_vec_f16_to_f32(my_row_buf1[ci]);
HVX_Vector t1_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r1_f32));
HVX_Vector t1_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r1_f32));
t1_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_lo, v_cap));
t1_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_hi, v_cap));
my_row_buf1[ci] = hvx_vec_f32_to_f16(t1_lo, t1_hi);
}
}
HVX_Vector v_slope0, v_slope1;
if (args->has_alibi) {
v_slope0 = hvx_vec_splat_f16(args->slopes[r + 0]);
v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_splat_f16(args->slopes[r + 1]) : Q6_V_vzero();
}
const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00);
HVX_Vector v_s_rowmax0 = v_neg_inf;
HVX_Vector v_s_rowmax1 = v_neg_inf;
for (size_t c = 0; c < kv_rows; c += 64) {
size_t ci = c / 64;
const size_t ne = hex_smin(kv_rows - c, 64);
HVX_VectorPred q_tail_keep = Q6_Q_vsetq2_R(ne * sizeof(__fp16));
if (args->mask) {
HVX_Vector v_mask0, v_mask1;
if (args->mask_vtcm) {
const size_t qi0 = (r + 0) / G;
v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c);
v_mask1 = v_neg_inf;
if (r + 1 < (int) n_rows_g) {
const size_t qi1 = (r + 1) / G;
if (qi1 == qi0) {
v_mask1 = v_mask0;
} else {
v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c);
}
}
} else {
const struct htp_tensor * mask = args->mask;
const size_t q_idx0 = args->q_start + ((r + 0) / G);
const size_t h_idx0 = args->kv_head * G + (r + 0) % G;
const uint32_t im2_0 = h_idx0 % mask->ne[2];
const uint32_t im3_0 = args->ib3 % mask->ne[3];
const __fp16 * m0_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx0 * mask->nb[1] +
im2_0 * mask->nb[2] + im3_0 * mask->nb[3]) + args->kv_start + c;
v_mask0 = *(const HVX_UVector *) m0_ptr;
v_mask1 = v_neg_inf;
if (r + 1 < (int) n_rows_g) {
const size_t q_idx1 = args->q_start + ((r + 1) / G);
if (q_idx1 == q_idx0) {
v_mask1 = v_mask0;
} else {
const size_t h_idx1 = args->kv_head * G + (r + 1) % G;
const uint32_t im2_1 = h_idx1 % mask->ne[2];
const uint32_t im3_1 = args->ib3 % mask->ne[3];
const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] +
im2_1 * mask->nb[2] + im3_1 * mask->nb[3]) + args->kv_start + c;
v_mask1 = *(const HVX_UVector *) m1_ptr;
}
}
}
HVX_VectorPred q_keep0 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask0, v_threshold), q_tail_keep);
HVX_VectorPred q_keep1 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask1, v_threshold), q_tail_keep);
if (args->has_alibi) {
HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0);
HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1);
my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf);
my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf);
} else {
my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_mask0), v_neg_inf);
my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_mask1), v_neg_inf);
}
} else {
if (ne < 64) {
my_row_buf0[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf0[ci], v_neg_inf);
my_row_buf1[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf1[ci], v_neg_inf);
}
}
v_s_rowmax0 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax0, my_row_buf0[ci]);
v_s_rowmax1 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax1, my_row_buf1[ci]);
}
v_s_rowmax0 = hvx_vec_reduce_max_f16(v_s_rowmax0);
v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1);
HVX_Vector v_m_prev0 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2)));
HVX_Vector v_m_prev1 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2)));
HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0);
HVX_Vector v_dup_m1 = Q6_Vhf_vmax_VhfVhf(v_m_prev1, v_s_rowmax1);
{
HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2);
HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2);
HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2);
HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start);
HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid);
rowmax_acc_v = Q6_V_vmux_QVV(p_lane0, v_dup_m0, rowmax_acc_v);
rowmax_acc_v = Q6_V_vmux_QVV(p_lane1, v_dup_m1, rowmax_acc_v);
}
const HVX_Vector v_zero = Q6_V_vzero();
HVX_Vector v_p_rowsum0 = v_zero;
HVX_Vector v_p_rowsum1 = v_zero;
#ifdef HMX_FA_USE_EXP2_HF
for (size_t c = 0; c < kv_rows; c += 64) {
size_t ci = c / 64;
HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0);
HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1);
HVX_Vector v_p_row0_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m0));
HVX_Vector v_p_row1_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m1));
#else
for (size_t c = 0; c < kv_rows; c += 64) {
size_t ci = c / 64;
HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0);
HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1);
HVX_VectorPair vp0 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m0));
HVX_Vector p0_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp0));
HVX_Vector p0_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp0));
HVX_Vector v_p_row0_hf = hvx_vec_f32_to_f16_shuff(p0_lo, p0_hi);
HVX_VectorPair vp1 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m1));
HVX_Vector p1_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp1));
HVX_Vector p1_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp1));
HVX_Vector v_p_row1_hf = hvx_vec_f32_to_f16_shuff(p1_lo, p1_hi);
#endif
__fp16 * out_dual_tile = p_st_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2;
HVX_Vector * pv_p_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2;
HVX_Vector * pv_p_out1 = pv_p_out0 + 16;
HVX_VectorPair vp_p_dual = Q6_W_vshuff_VVR(v_p_row1_hf, v_p_row0_hf, -2);
*pv_p_out0 = Q6_V_lo_W(vp_p_dual);
*pv_p_out1 = Q6_V_hi_W(vp_p_dual);
HVX_VectorPair vp_p0 = hvx_vec_f16_to_f32_shuff(v_p_row0_hf);
HVX_VectorPair vp_p1 = hvx_vec_f16_to_f32_shuff(v_p_row1_hf);
v_p_rowsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum0, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p0), Q6_V_hi_W(vp_p0)));
v_p_rowsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum1, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p1), Q6_V_hi_W(vp_p1)));
}
HVX_Vector rowsum0_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum0));
HVX_Vector rowsum1_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum1));
{
HVX_Vector rv0_v = hvx_vec_f32_to_f16(rowsum0_sf, rowsum0_sf);
HVX_Vector rv1_v = hvx_vec_f32_to_f16(rowsum1_sf, rowsum1_sf);
HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2);
HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2);
HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2);
HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start);
HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid);
rowsum_acc_v = Q6_V_vmux_QVV(p_lane0, rv0_v, rowsum_acc_v);
rowsum_acc_v = Q6_V_vmux_QVV(p_lane1, rv1_v, rowsum_acc_v);
}
}
factx->vtcm_s_rowmax[r_vec_idx] = rowmax_acc_v;
factx->vtcm_p_rowsum[r_vec_idx] = rowsum_acc_v;
}
}
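// Single-threaded epilogue of the softmax phase: fold the new block's row max/sum into
// the running (m, l) state, then scatter exp(m_prev - m_curr) onto the diagonals of the
// D tiles that rescale the previous O accumulator. The dummy volatile load after each
// scatter forces the pending scatter to complete before the HMX consumes the tiles.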
static __attribute__((noinline)) void fa_ml_update_and_build_d(struct hmx_fa_context * factx,
size_t n_rows_g,
size_t n_row_tiles,
size_t n_row_tiles_g_br) {
HVX_Vector * const mvec_exp_m_diff = factx->vtcm_s_rowmax;
const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64);
for (size_t i = 0; i < n_row_vec_cnt; ++i) {
HVX_Vector v_m_prev = factx->vtcm_m_vec[i];
HVX_Vector v_m_curr = Q6_Vhf_vmax_VhfVhf(v_m_prev, factx->vtcm_s_rowmax[i]);
HVX_Vector v_m_diff = Q6_Vqf16_vsub_VhfVhf(v_m_prev, v_m_curr);
#ifdef HMX_FA_USE_EXP2_HF
HVX_Vector v_exp_m_diff = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_m_diff));
#else
HVX_VectorPair vp_diff = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_m_diff));
HVX_Vector exp_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp_diff));
HVX_Vector exp_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp_diff));
HVX_Vector v_exp_m_diff = hvx_vec_f32_to_f16_shuff(exp_lo, exp_hi);
#endif
HVX_Vector v_l_curr = Q6_Vqf16_vmpy_Vqf16Vhf(factx->vtcm_l_vec[i], v_exp_m_diff);
v_l_curr = Q6_Vqf16_vadd_Vqf16Vhf(v_l_curr, factx->vtcm_p_rowsum[i]);
factx->vtcm_m_vec[i] = v_m_curr;
factx->vtcm_l_vec[i] = v_l_curr;
mvec_exp_m_diff[i] = v_exp_m_diff;
}
const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets;
const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16));
for (size_t i = 0; i < n_row_tiles; ++i) {
const HVX_Vector v_content = Q6_V_vror_VR(mvec_exp_m_diff[i / 2], (i % 2) * 64);
__fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS;
Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content);
__asm__ __volatile__("" ::: "memory");
(void) *(volatile HVX_Vector *) out_base;
}
}
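// Builds the final normalization D tiles: each diagonal holds 1/l per row, computed in
// fp32 for accuracy, so one last HMX pass (D * O) divides every output row by its
// softmax denominator.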
static __attribute__((noinline)) void fa_build_d_diag_inv_l(struct hmx_fa_context * factx,
size_t n_row_tiles,
size_t n_row_tiles_g_br) {
const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets;
const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16));
const HVX_Vector one = hvx_vec_splat_f32(1.0f);
HVX_Vector v_content = Q6_V_vzero();
for (size_t i = 0; i < n_row_tiles; ++i) {
if ((i % 2) == 0) {
HVX_Vector v_l_hf = Q6_Vhf_equals_Vqf16(factx->vtcm_l_vec[i / 2]);
HVX_VectorPair vp_l = hvx_vec_f16_to_f32_shuff(v_l_hf);
HVX_Vector inv_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_lo_W(vp_l))));
HVX_Vector inv_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_hi_W(vp_l))));
v_content = hvx_vec_f32_to_f16_shuff(inv_lo, inv_hi);
} else {
v_content = Q6_V_vror_VR(v_content, 64);
}
__fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS;
Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content);
__asm__ __volatile__("" ::: "memory");
(void) *(volatile HVX_Vector *) out_base;
}
}
static void fa_phase_softmax_and_build_d(struct hmx_fa_context * factx,
fa_softmax_args_t * sargs,
size_t n_row_tiles,
size_t n_row_tiles_g_br) {
worker_pool_context_t wp = factx->octx->ctx->worker_pool;
const size_t n_row_vec_cnt = hmx_ceil_div(sargs->n_rows_g, 64);
if (factx->n_threads > 1 && n_row_vec_cnt >= 2) {
uint32_t n_use = (uint32_t) hex_smin((size_t) factx->n_threads, n_row_vec_cnt);
worker_pool_run_func(wp, fa_softmax_thread, sargs, n_use);
} else {
fa_softmax_thread(1, 0, sargs);
}
fa_ml_update_and_build_d(factx, sargs->n_rows_g, n_row_tiles, n_row_tiles_g_br);
}
typedef struct {
const __fp16 * q_tiles;
const __fp16 * k_tiles;
__fp16 * s_tiles;
size_t n_row_tiles;
size_t n_col_tiles;
size_t n_dot_tiles;
size_t n_tiles_per_bc;
uint8_t * hmx_scales;
} hmx_fa_qk_job_t;
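// HMX-queue job: S = scale * (Q K^T) for one KV block, producing one 32x32 output tile
// per (row, col) pair and accumulating over the DK dimension. The qk scale is applied
// through the preloaded HMX column-scale bias.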
static void hmx_fa_qk_dot_worker(void * data) {
hmx_fa_qk_job_t * job = (hmx_fa_qk_job_t *) data;
const size_t n_row_tiles = job->n_row_tiles;
const size_t n_col_tiles = job->n_col_tiles;
const size_t n_dot_tiles = job->n_dot_tiles;
const size_t n_tiles_per_bc = job->n_tiles_per_bc;
const __fp16 * restrict q_tiles = job->q_tiles;
const __fp16 * restrict k_tiles = job->k_tiles;
__fp16 * restrict s_tiles = job->s_tiles;
__builtin_assume(n_row_tiles > 0);
__builtin_assume(n_col_tiles > 0);
__builtin_assume(n_dot_tiles > 0);
Q6_bias_mxmem2_A((void *) job->hmx_scales);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < n_col_tiles; ++c) {
const __fp16 * row_tiles = q_tiles + r * HMX_FP16_TILE_N_ROWS * n_dot_tiles * HMX_FP16_TILE_N_COLS;
const __fp16 * col_tiles = k_tiles + c * HMX_FP16_TILE_N_COLS * n_dot_tiles * HMX_FP16_TILE_N_COLS;
__fp16 * out_tile = s_tiles + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS;
for (size_t k = 0; k < n_dot_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047);
row_tiles += HMX_FP16_TILE_N_ELMS;
col_tiles += HMX_FP16_TILE_N_ELMS;
}
Q6_mxmem_AR_after_hf(out_tile, 0);
}
}
}
typedef struct {
__fp16 * o_curr;
const __fp16 * o_prev;
const __fp16 * p_tiles;
const __fp16 * v_tiles;
const __fp16 * d_tiles;
uint8_t * hmx_scales;
size_t n_row_tiles;
size_t n_col_tiles;
size_t n_row_tiles_g_br;
size_t n_tiles_per_bc;
size_t DV;
} hmx_fa_o_update_job_t;
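// HMX-queue job: O_curr = D * O_prev + P * V for one KV block. The diagonal D tile
// rescales the previous accumulator while the P*V product for the new block accumulates
// into the same output tile; identity column scales are used here.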
static void hmx_fa_o_update_worker(void * data) {
hmx_fa_o_update_job_t * job = (hmx_fa_o_update_job_t *) data;
const size_t n_row_tiles = job->n_row_tiles;
const size_t n_col_tiles = job->n_col_tiles;
const size_t n_row_tiles_g_br = job->n_row_tiles_g_br;
const size_t n_tiles_per_bc = job->n_tiles_per_bc;
const size_t DV_tiles = job->DV / 32;
const __fp16 * restrict d_tiles = job->d_tiles;
const __fp16 * restrict p_tiles = job->p_tiles;
const __fp16 * restrict v_tiles = job->v_tiles;
const __fp16 * restrict o_prev = job->o_prev;
__fp16 * restrict o_curr = job->o_curr;
__builtin_assume(n_row_tiles > 0);
__builtin_assume(n_col_tiles > 0);
__builtin_assume(DV_tiles > 0);
Q6_bias_mxmem2_A((void *) job->hmx_scales);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < DV_tiles; ++c) {
const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS;
const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS;
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
const __fp16 * p_tile_in = p_tiles + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS;
const __fp16 * v_tile_in = v_tiles + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS;
for (size_t k = 0; k < n_col_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047);
p_tile_in += HMX_FP16_TILE_N_ELMS;
v_tile_in += HMX_FP16_TILE_N_ELMS;
}
__fp16 * o_tile_out = o_curr + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS;
Q6_mxmem_AR_after_hf(o_tile_out, 0);
}
}
}
typedef struct {
__fp16 * o_curr;
const __fp16 * o_prev;
const __fp16 * d_tiles;
uint8_t * hmx_scales;
size_t n_row_tiles;
size_t n_row_tiles_g_br;
size_t DV;
} hmx_fa_o_norm_job_t;
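// HMX-queue job: final O = D(1/l) * O_prev, also reordering tiles from the column-major
// accumulator layout to the row-major layout fa_phase_o_store() expects.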
static void hmx_fa_o_norm_worker(void * data) {
hmx_fa_o_norm_job_t * job = (hmx_fa_o_norm_job_t *) data;
const size_t n_row_tiles = job->n_row_tiles;
const size_t n_row_tiles_g_br = job->n_row_tiles_g_br;
const size_t DV_tiles = job->DV / 32;
const __fp16 * restrict d_tiles = job->d_tiles;
const __fp16 * restrict o_prev = job->o_prev;
__fp16 * restrict o_curr = job->o_curr;
__builtin_assume(n_row_tiles > 0);
__builtin_assume(DV_tiles > 0);
Q6_bias_mxmem2_A((void *) job->hmx_scales);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < DV_tiles; ++c) {
const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS;
const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS;
__fp16 * o_out = o_curr + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS;
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
Q6_mxmem_AR_after_hf(o_out, 0);
}
}
}
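// Fills the per-row ALiBi slope table for this q block: the slope depends only on the
// head index, following the usual m0^(h+1) / m1^(2(h-n_head_log2)+1) schedule. With
// max_bias == 0 the slopes default to 1, though the softmax path adds the mask unscaled
// in that case.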
static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sargs,
const struct hmx_fa_context * factx,
uint32_t kv_head,
size_t n_rows_g) {
if (factx->max_bias == 0.0f) {
for (size_t r = 0; r < n_rows_g; ++r) {
sargs->slopes[r] = 1.0f;
}
return;
}
const uint32_t G = factx->G;
const uint32_t n_head_log2 = factx->n_head_log2;
const float m0 = factx->m0;
const float m1 = factx->m1;
for (size_t r = 0; r < n_rows_g; ++r) {
const uint32_t h = kv_head * G + r % G;
sargs->slopes[r] = (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1);
}
}
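// Flash-attention entry point on the HMX engine. Per (batch, kv_head, q block): load and
// convert Q into tiles, then stream KV blocks through a DMA -> interleave -> QK^T ->
// online softmax -> O update loop (software-pipelined across the HMX queue when enough
// KV blocks and threads are available), normalize by 1/l, and store O. Returns
// HTP_STATUS_NO_SUPPORT for unsupported shapes and HTP_STATUS_VTCM_TOO_SMALL when no
// tiling fits the budget.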
int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const struct htp_tensor * q = octx->src[0];
const struct htp_tensor * k = octx->src[1];
const struct htp_tensor * v = octx->src[2];
const struct htp_tensor * mask = (octx->src[3] && octx->src[3]->data) ? octx->src[3] : NULL;
const struct htp_tensor * dst = octx->dst;
struct htp_context * const ctx = octx->ctx;
if (!ctx->hmx_enabled) {
return HTP_STATUS_NO_SUPPORT;
}
const uint32_t neq0 = q->ne[0];
const uint32_t neq1 = q->ne[1];
const uint32_t neq2 = q->ne[2];
const uint32_t neq3 = q->ne[3];
const uint32_t nek0 = k->ne[0];
const uint32_t nek1 = k->ne[1];
const uint32_t nev0 = v->ne[0];
const uint32_t DK = neq0;
const uint32_t DV = nev0;
if (DK % 32 != 0 || DV % 32 != 0) {
return HTP_STATUS_NO_SUPPORT;
}
if (neq1 < 32) {
return HTP_STATUS_NO_SUPPORT;
}
const uint32_t n_kv_heads = k->ne[2];
const uint32_t G = neq2 / n_kv_heads;
const uint32_t n_threads = octx->n_threads;
size_t Br, Bc;
const size_t vtcm_budget = ctx->vtcm_size;
if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads) != 0) {
return HTP_STATUS_VTCM_TOO_SMALL;
}
const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS);
const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc;
const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2);
FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu",
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget);
struct hmx_fa_context factx;
memset(&factx, 0, sizeof(factx));
factx.octx = octx;
factx.n_threads = octx->ctx->n_threads;
factx.DK = DK;
factx.DV = DV;
factx.n_kv = nek1;
factx.n_kv_heads = n_kv_heads;
factx.n_heads = neq2;
factx.G = G;
factx.neq1 = neq1;
factx.Br = (uint32_t) Br;
factx.Bc = (uint32_t) Bc;
factx.g_br = (uint32_t) g_br;
factx.n_kv_blocks = n_kv_blocks;
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32);
factx.use_pipeline = use_pipeline;
factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1);
float scale = 1.0f, max_bias = 0.0f, logit_softcap = 0.0f;
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
#ifdef HMX_FA_USE_EXP2_HF
if (logit_softcap == 0.0f) {
scale *= 1.44269504f;
}
#endif
factx.scale = scale;
factx.max_bias = max_bias;
factx.logit_softcap = logit_softcap;
factx.n_head_log2 = 1u << (uint32_t) floor(log2(neq2));
factx.m0 = powf(2.0f, -(max_bias) / factx.n_head_log2);
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
const size_t q_tile_bytes = hex_align_up(g_br * DK * sizeof(__fp16), 4096);
const size_t o_tile_bytes = hex_align_up(g_br * DV * sizeof(__fp16), 4096);
const size_t k_dma_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096);
const size_t v_dma_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096);
const size_t k_tile_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096);
const size_t v_tile_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096);
const size_t s_tile_bytes = hex_align_up(g_br * Bc * sizeof(__fp16), 4096);
const size_t d_tile_bytes = hex_align_up(g_br * g_br * sizeof(__fp16), 4096);
const size_t col_vec_bytes = hex_align_up(g_br * sizeof(__fp16), 256);
const size_t row_vec_bytes = hex_align_up(Bc * sizeof(__fp16), 256);
const size_t m_line_bytes = hex_align_up(Bc * sizeof(__fp16), 128);
const size_t m_buf_bytes = hex_align_up(Br * m_line_bytes, 4096);
const size_t slopes_bytes = hex_align_up(g_br * sizeof(__fp16), 128);
uint8_t * vtcm_cur = ctx->vtcm_base;
factx.vtcm_q_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, q_tile_bytes);
factx.vtcm_o_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes);
factx.vtcm_o_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes);
factx.vtcm_k_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes);
factx.vtcm_k_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes);
factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes);
factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes);
factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes);
factx.vtcm_v_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes);
factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes);
factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes);
factx.vtcm_m_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes);
factx.vtcm_l_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes);
factx.vtcm_s_rowmax = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes);
factx.vtcm_p_rowsum = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes);
factx.vtcm_row_bufs = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, row_vec_bytes * 2 * n_threads);
factx.row_buf_stride = row_vec_bytes / sizeof(HVX_Vector);
factx.vtcm_hmx_scales_id = vtcm_seq_alloc(&vtcm_cur, 256);
factx.vtcm_hmx_scales_qk = vtcm_seq_alloc(&vtcm_cur, 256);
factx.vtcm_mask_buf = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, m_buf_bytes);
factx.mask_buf_row_stride = m_line_bytes / sizeof(__fp16);
factx.vtcm_slopes = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, slopes_bytes);
if ((size_t) (vtcm_cur - ctx->vtcm_base) > ctx->vtcm_size) {
return HTP_STATUS_VTCM_TOO_SMALL;
}
hmx_init_column_scales(factx.vtcm_hmx_scales_id, Q6_V_vsplat_R(0x3c00));
hmx_init_column_scales(factx.vtcm_hmx_scales_qk, hvx_vec_splat_f16(factx.scale));
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
return HTP_STATUS_OK;
}
TIMER_DEFINE(total);
TIMER_DEFINE(q_load);
TIMER_DEFINE(kv_dma);
TIMER_DEFINE(k_interleave);
TIMER_DEFINE(v_interleave);
TIMER_DEFINE(qk_dot);
TIMER_DEFINE(softmax);
TIMER_DEFINE(o_update);
TIMER_DEFINE(o_norm);
TIMER_DEFINE(o_store);
TIMER_START(total);
dma_queue * const dma = ctx->dma[0];
const size_t size_k_row = nek0 * sizeof(__fp16);
const size_t size_v_row = nev0 * sizeof(__fp16);
const size_t size_k_row_padded = hex_round_up(nek0 * sizeof(__fp16), 128);
const size_t size_v_row_padded = hex_round_up(nev0 * sizeof(__fp16), 128);
const size_t n_row_tiles_g_br = g_br / HMX_FP16_TILE_N_ROWS;
const size_t n_tiles_per_bc = Bc / HMX_FP16_TILE_N_COLS;
const size_t qo_element_size = factx.is_q_fp32 ? sizeof(float) : sizeof(__fp16);
if (!factx.use_pipeline) {
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
}
hmx_fa_qk_job_t qk_job;
hmx_fa_o_update_job_t ou_job;
hmx_fa_o_norm_job_t on_job;
for (uint32_t ib3 = 0; ib3 < neq3; ++ib3) {
for (uint32_t kv_head = 0; kv_head < n_kv_heads; ++kv_head) {
const uint32_t ik2 = kv_head;
const uint32_t ik3 = ib3 / (neq3 / k->ne[3]);
const uint32_t iv2 = kv_head;
const uint32_t iv3 = ib3 / (neq3 / v->ne[3]);
for (uint32_t q_start = 0; q_start < neq1; q_start += Br) {
const uint32_t n_q_rows = hex_smin(Br, neq1 - q_start);
const size_t n_rows_g = n_q_rows * G;
const size_t g_br_actual = hex_align_up(n_rows_g, HMX_FP16_TILE_N_ROWS);
const size_t n_row_tiles = g_br_actual / HMX_FP16_TILE_N_ROWS;
TIMER_START(q_load);
if (n_rows_g < g_br) {
hvx_splat_u8_a(factx.vtcm_q_tiles, 0, q_tile_bytes);
}
fa_phase_q_load(&factx, q, q_start, kv_head, ib3, n_rows_g);
TIMER_STOP(q_load);
hvx_splat_u8_a(factx.vtcm_l_vec, 0, col_vec_bytes);
hvx_splat_u8_a(factx.vtcm_d_tiles, 0, d_tile_bytes);
hvx_splat_u16_a(factx.vtcm_m_vec, 0xfbff, col_vec_bytes / 2);
__fp16 * o_tile_prev = factx.vtcm_o_tiles[0];
__fp16 * o_tile_curr = factx.vtcm_o_tiles[1];
hvx_splat_u8_a(o_tile_prev, 0, o_tile_bytes);
size_t buf_idx = 0;
if (factx.n_kv_blocks > 0) {
const uint32_t kv_rows0 = hex_smin(Bc, nek1);
const uint8_t * k_src = (const uint8_t *) k->data + ik2 * k->nb[2] + ik3 * k->nb[3];
dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[0], k_src), size_k_row_padded, k->nb[1],
size_k_row, kv_rows0);
const uint8_t * v_src = (const uint8_t *) v->data + iv2 * v->nb[2] + iv3 * v->nb[3];
dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[0], v_src), size_v_row_padded, v->nb[1],
size_v_row, kv_rows0);
}
#define MASK_DMA_PUSH(kv_start_val, kv_rows_val, has_mask_dma_var) \
do { \
has_mask_dma_var = false; \
if (mask && factx.mask_broadcast) { \
const uint32_t _im3 = ib3 % mask->ne[3]; \
const uint8_t * _ms = (const uint8_t *) mask->data + q_start * mask->nb[1] + _im3 * mask->nb[3] + \
(kv_start_val) * sizeof(__fp16); \
dma_queue_push(dma, dma_make_ptr(factx.vtcm_mask_buf, _ms), m_line_bytes, mask->nb[1], \
(kv_rows_val) * sizeof(__fp16), n_q_rows); \
has_mask_dma_var = true; \
} \
} while (0)
#define MASK_DMA_POP(has_mask_dma_var) \
do { \
if (has_mask_dma_var) { \
dma_queue_pop(dma); \
} \
} while (0)
#define DMA_PREFETCH_KV(blk_val) \
do { \
if ((blk_val) < factx.n_kv_blocks) { \
const uint32_t _ns = (blk_val) * Bc; \
const uint32_t _nr = hex_smin(Bc, nek1 - _ns); \
size_t _nb = 1 - buf_idx; \
const uint8_t * _ks = (const uint8_t *) k->data + _ns * k->nb[1] + ik2 * k->nb[2] + ik3 * k->nb[3]; \
dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[_nb], _ks), size_k_row_padded, k->nb[1], size_k_row, _nr); \
const uint8_t * _vs = (const uint8_t *) v->data + _ns * v->nb[1] + iv2 * v->nb[2] + iv3 * v->nb[3]; \
dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[_nb], _vs), size_v_row_padded, v->nb[1], size_v_row, _nr); \
} \
} while (0)
const size_t k_src_stride = size_k_row_padded / sizeof(__fp16);
const size_t v_src_stride = size_v_row_padded / sizeof(__fp16);
if (factx.use_pipeline) {
struct hmx_queue * hmx_q = ctx->hmx_queue;
for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) {
const uint32_t kv_start = kv_blk * Bc;
const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start);
const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS);
TIMER_START(kv_dma);
dma_queue_pop(dma);
dma_queue_pop(dma);
TIMER_STOP(kv_dma);
bool has_mask_dma = false;
MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma);
if (kv_blk > 0) {
ou_job.o_curr = o_tile_curr;
ou_job.o_prev = o_tile_prev;
ou_job.p_tiles = factx.vtcm_p_tiles;
ou_job.v_tiles = factx.vtcm_v_tiles;
ou_job.d_tiles = factx.vtcm_d_tiles;
ou_job.hmx_scales = factx.vtcm_hmx_scales_id;
ou_job.n_row_tiles = n_row_tiles;
ou_job.n_col_tiles = hmx_ceil_div(hex_smin(Bc, nek1 - (kv_blk - 1) * Bc), HMX_FP16_TILE_N_COLS);
ou_job.n_row_tiles_g_br = n_row_tiles_g_br;
ou_job.n_tiles_per_bc = n_tiles_per_bc;
ou_job.DV = DV;
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job));
}
TIMER_START(k_interleave);
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
TIMER_STOP(k_interleave);
if (kv_blk > 0) {
hmx_queue_pop(hmx_q);
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
}
qk_job.q_tiles = factx.vtcm_q_tiles;
qk_job.k_tiles = factx.vtcm_k_tiles;
qk_job.s_tiles = factx.vtcm_s_tiles;
qk_job.n_row_tiles = n_row_tiles;
qk_job.n_col_tiles = n_col_tiles;
qk_job.n_dot_tiles = DK / 32;
qk_job.n_tiles_per_bc = n_tiles_per_bc;
qk_job.hmx_scales = factx.vtcm_hmx_scales_qk;
TIMER_START(qk_dot);
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_qk_dot_worker, &qk_job));
DMA_PREFETCH_KV(kv_blk + 1);
TIMER_START(v_interleave);
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
TIMER_STOP(v_interleave);
hmx_queue_pop(hmx_q);
TIMER_STOP(qk_dot);
MASK_DMA_POP(has_mask_dma);
fa_softmax_args_t sargs;
memset(&sargs, 0, sizeof(sargs));
sargs.factx = &factx;
sargs.kv_rows = kv_rows;
sargs.n_rows_g = n_rows_g;
sargs.n_col_tiles = n_col_tiles;
sargs.n_tiles_per_bc = n_tiles_per_bc;
sargs.n_row_tiles = n_row_tiles;
sargs.n_row_tiles_g_br = n_row_tiles_g_br;
sargs.Bc = Bc;
sargs.G = G;
sargs.kv_head = kv_head;
sargs.kv_start = kv_start;
sargs.q_start = q_start;
sargs.ib3 = ib3;
sargs.has_alibi = (factx.max_bias != 0.0f);
sargs.mask = mask;
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
sargs.slopes = factx.vtcm_slopes;
fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g);
TIMER_START(softmax);
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
TIMER_STOP(softmax);
buf_idx = 1 - buf_idx;
}
if (factx.n_kv_blocks > 0) {
const uint32_t last_blk = factx.n_kv_blocks - 1;
const size_t last_cols = hmx_ceil_div(hex_smin(Bc, nek1 - last_blk * Bc), HMX_FP16_TILE_N_COLS);
ou_job.o_curr = o_tile_curr;
ou_job.o_prev = o_tile_prev;
ou_job.p_tiles = factx.vtcm_p_tiles;
ou_job.v_tiles = factx.vtcm_v_tiles;
ou_job.d_tiles = factx.vtcm_d_tiles;
ou_job.hmx_scales = factx.vtcm_hmx_scales_id;
ou_job.n_row_tiles = n_row_tiles;
ou_job.n_col_tiles = last_cols;
ou_job.n_row_tiles_g_br = n_row_tiles_g_br;
ou_job.n_tiles_per_bc = n_tiles_per_bc;
ou_job.DV = DV;
TIMER_START(o_update);
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job));
hmx_queue_pop(hmx_q);
TIMER_STOP(o_update);
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
}
} else {
for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) {
const uint32_t kv_start = kv_blk * Bc;
const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start);
const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS);
TIMER_START(kv_dma);
dma_queue_pop(dma);
dma_queue_pop(dma);
TIMER_STOP(kv_dma);
bool has_mask_dma = false;
MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma);
DMA_PREFETCH_KV(kv_blk + 1);
TIMER_START(k_interleave);
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
TIMER_STOP(k_interleave);
TIMER_START(qk_dot);
{
const size_t n_dot_tiles = (size_t) (DK / 32);
const __fp16 * restrict q_base = factx.vtcm_q_tiles;
const __fp16 * restrict k_base = factx.vtcm_k_tiles;
__fp16 * restrict s_base = factx.vtcm_s_tiles;
__builtin_assume(n_row_tiles > 0);
__builtin_assume(n_col_tiles > 0);
__builtin_assume(n_dot_tiles > 0);
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_qk);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < n_col_tiles; ++c) {
const __fp16 * row_tiles = q_base + r * HMX_FP16_TILE_N_ROWS * DK;
const __fp16 * col_tiles = k_base + c * HMX_FP16_TILE_N_COLS * DK;
__fp16 * out_tile = s_base + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS;
for (size_t k = 0; k < n_dot_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047);
row_tiles += HMX_FP16_TILE_N_ELMS;
col_tiles += HMX_FP16_TILE_N_ELMS;
}
Q6_mxmem_AR_after_hf(out_tile, 0);
}
}
}
TIMER_STOP(qk_dot);
MASK_DMA_POP(has_mask_dma);
fa_softmax_args_t sargs;
memset(&sargs, 0, sizeof(sargs));
sargs.factx = &factx;
sargs.kv_rows = kv_rows;
sargs.n_rows_g = n_rows_g;
sargs.n_col_tiles = n_col_tiles;
sargs.n_tiles_per_bc = n_tiles_per_bc;
sargs.n_row_tiles = n_row_tiles;
sargs.n_row_tiles_g_br = n_row_tiles_g_br;
sargs.Bc = Bc;
sargs.G = G;
sargs.kv_head = kv_head;
sargs.kv_start = kv_start;
sargs.q_start = q_start;
sargs.ib3 = ib3;
sargs.has_alibi = (factx.max_bias != 0.0f);
sargs.mask = mask;
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
sargs.slopes = factx.vtcm_slopes;
fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g);
TIMER_START(softmax);
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
TIMER_STOP(softmax);
TIMER_START(v_interleave);
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
TIMER_STOP(v_interleave);
TIMER_START(o_update);
{
const size_t DV_tiles = (size_t) (DV / 32);
const __fp16 * restrict d_base = factx.vtcm_d_tiles;
const __fp16 * restrict p_base = factx.vtcm_p_tiles;
const __fp16 * restrict v_base = factx.vtcm_v_tiles;
const __fp16 * restrict op_base = o_tile_prev;
__fp16 * restrict oc_base = o_tile_curr;
__builtin_assume(n_row_tiles > 0);
__builtin_assume(n_col_tiles > 0);
__builtin_assume(DV_tiles > 0);
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < DV_tiles; ++c) {
const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS;
const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS;
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
const __fp16 * p_tile_in = p_base + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS;
const __fp16 * v_tile_in = v_base + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS;
for (size_t k = 0; k < n_col_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047);
p_tile_in += HMX_FP16_TILE_N_ELMS;
v_tile_in += HMX_FP16_TILE_N_ELMS;
}
__fp16 * o_tile_out = oc_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS;
Q6_mxmem_AR_after_hf(o_tile_out, 0);
}
}
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
}
TIMER_STOP(o_update);
buf_idx = 1 - buf_idx;
}
}
TIMER_START(o_norm);
{
fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br);
if (factx.use_pipeline) {
on_job.o_curr = o_tile_curr;
on_job.o_prev = o_tile_prev;
on_job.d_tiles = factx.vtcm_d_tiles;
on_job.hmx_scales = factx.vtcm_hmx_scales_id;
on_job.n_row_tiles = n_row_tiles;
on_job.n_row_tiles_g_br = n_row_tiles_g_br;
on_job.DV = DV;
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_fa_o_norm_worker, &on_job));
hmx_queue_pop(ctx->hmx_queue);
} else {
const size_t DV_tiles = (size_t) (DV / 32);
const __fp16 * restrict d_base = factx.vtcm_d_tiles;
const __fp16 * restrict op_base = o_tile_prev;
__fp16 * restrict oc_base = o_tile_curr;
__builtin_assume(n_row_tiles > 0);
__builtin_assume(DV_tiles > 0);
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < DV_tiles; ++c) {
const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS;
const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS;
__fp16 * o_out = oc_base + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS;
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
Q6_mxmem_AR_after_hf(o_out, 0);
}
}
}
}
TIMER_STOP(o_norm);
TIMER_START(o_store);
fa_phase_o_store(&factx, dst, o_tile_curr, q_start, kv_head, ib3, n_rows_g);
TIMER_STOP(o_store);
#undef MASK_DMA_PUSH
#undef MASK_DMA_POP
#undef DMA_PREFETCH_KV
}
}
}
if (factx.use_pipeline) {
hmx_queue_suspend(ctx->hmx_queue);
} else {
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
}
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "hmx-fa: %lld us, q_load=%lld kv_dma=%lld k_interleave=%lld v_interleave=%lld", TIMER_US(total),
TIMER_US(q_load), TIMER_US(kv_dma), TIMER_US(k_interleave), TIMER_US(v_interleave));
FARF(HIGH, " qk_dot=%lld softmax=%lld o_update=%lld o_norm=%lld o_store=%lld", TIMER_US(qk_dot), TIMER_US(softmax),
TIMER_US(o_update), TIMER_US(o_norm), TIMER_US(o_store));
#endif
return HTP_STATUS_OK;
}