#include "../llama-build-context.h"
#include "../llama-model.h"
#include "../llama-context.h"
#include <vector>
// Builds the attention sub-graph of MLA layer `il` when the layer's weights are split
// across devices (tensor-parallel split modes).
//
// For every device `id` that owns a slice of `wo`, the device-local slices of the attention
// weights (wq / wq_a+wq_b, wkv_a_mqa, wk_b / wv_b / wk_b_pp, the RMS norms) and the
// device-local copy of the layer's KV cache are used to compute a partial `wo` projection.
// The device input (get_input_tensor_sm_graph(inpL, id)) is added to the FIRST partial only,
// so the reduced sum contains the layer input exactly once. Partials are combined with
// ggml_reduce(..., GGML_OP_ADD).
//
// Parameters:
//   gf                      graph being built (nodes are expanded into it)
//   il                      layer index
//   inpL                    layer input
//   KQ_mask, inp_pos        attention mask / token positions
//   rope_cache              if non-null, RoPE is applied with ggml_rope_fast from this cache,
//                           otherwise with ggml_rope_ext
//   kq_scale                softmax scale passed to flash attention
//   attn_factor_scaled      attn_factor passed to ggml_rope_ext
//   use_f32_attn_precision  force GGML_PREC_F32 on flash attention
//   is_lite                 single wq projection instead of wq_a -> norm -> wq_b
//   pp_opt                  allow the decompressed-K/V flash-attention path (additionally
//                           requires n_kv >= 1024 and wk_b, wv_b, wk_b_pp to be present)
// Returns the reduced (summed) output of all device partials.
// Aborts unless flash attention is on and mla_attn >= 1.
ggml_tensor * llm_build_context::build_deepseek2_tp_attention(
        ggml_cgraph * gf, int il,
        ggml_tensor * inpL,
        ggml_tensor * KQ_mask, ggml_tensor * inp_pos,
        ggml_tensor * rope_cache,
        float kq_scale, float attn_factor_scaled,
        bool use_f32_attn_precision,
        bool is_lite,
        bool pp_opt) {
    if (!lctx.cparams.flash_attn || lctx.cparams.mla_attn < 1) {
        GGML_ABORT("-sm graph for MLA archs (DEEPSEEK2/GLM_DSA/MISTRAL4) requires -fa on and -mla >= 1. "
                   "Got mla_attn=%d, flash_attn=%d.",
                   (int)lctx.cparams.mla_attn, (int)lctx.cparams.flash_attn);
    }

    // The per-device slices of wo decide which devices take part in this layer.
    auto wo_split = (const ggml_split_tensor_t *)model.layers[il].wo->extra;
    GGML_ASSERT(wo_split);
    const int n_device = wo_split->n_device;

    const uint32_t n_embd_head_qk_rope = hparams.n_rot;
    const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k(il) - hparams.n_rot;
    const uint32_t kv_lora_rank        = hparams.n_lora_kv;
    const uint32_t n_embd_head_k       = hparams.n_embd_head_k(il);
    const uint32_t n_embd_head_v       = hparams.n_embd_head_v(il);

    // Each device must hold its own copy of this layer's KV cache.
    auto cache_repl = (const ggml_split_tensor_t *)kv_self.k_l[il]->extra;
    if (!cache_repl) {
        LLAMA_LOG_ERROR("%s: no cache split for layer %d?\n", __func__, il);
    }
    GGML_ASSERT(cache_repl);
    GGML_ASSERT(cache_repl->n_device == n_device);

    std::vector<ggml_tensor *> attn_partials(n_device, nullptr);
    bool input_added = false;

    for (int id = 0; id < n_device; ++id) {
        if (!wo_split->splits[id]) continue;  // this device holds no part of wo for this layer
        const int il_id = 1000 * il + id;
        auto input = get_input_tensor_sm_graph(ctx0, inpL, id);

        auto attn_norm_split = (const ggml_split_tensor_t *)model.layers[il].attn_norm->extra;
        GGML_ASSERT(attn_norm_split);
        ggml_tensor * cur = llm_build_norm(ctx0, input, hparams,
                attn_norm_split->splits[id], nullptr, LLM_NORM_RMS, cb, il_id);

        // Query projection: wq_a -> RMS norm -> wq_b, or a single wq for lite models.
        ggml_tensor * q;
        if (!is_lite) {
            auto wq_a_split = (const ggml_split_tensor_t *)model.layers[il].wq_a->extra;
            auto wq_b_split = (const ggml_split_tensor_t *)model.layers[il].wq_b->extra;
            GGML_ASSERT(wq_a_split && wq_b_split);
            q = ggml_mul_mat(ctx0, wq_a_split->splits[id], cur);
            ggml_build_forward_expand(gf, q);
            auto q_a_norm_split = (const ggml_split_tensor_t *)model.layers[il].attn_q_a_norm->extra;
            GGML_ASSERT(q_a_norm_split);
            q = llm_build_norm(ctx0, q, hparams, q_a_norm_split->splits[id], nullptr, LLM_NORM_RMS, cb, il_id);
            q = ggml_mul_mat(ctx0, wq_b_split->splits[id], q);
        } else {
            auto wq_split = (const ggml_split_tensor_t *)model.layers[il].wq->extra;
            GGML_ASSERT(wq_split);
            q = ggml_mul_mat(ctx0, wq_split->splits[id], cur);
            ggml_build_forward_expand(gf, q);
        }
        cb(q, "q", il_id);

        // Heads owned by this device follow from the width of the local q projection.
        const int n_head_local = q->ne[0] / n_embd_head_k;
        const size_t row_size_q = ggml_row_size(q->type, n_embd_head_k);
        // Each head's row is [nope | rope]; view the two parts separately.
        ggml_tensor * q_nope = ggml_view_3d(ctx0, q,
                n_embd_head_qk_nope, n_head_local, n_tokens,
                row_size_q, q->nb[1], 0);
        ggml_tensor * q_rope = ggml_view_3d(ctx0, q,
                n_embd_head_qk_rope, n_head_local, n_tokens,
                row_size_q, q->nb[1],
                ggml_row_size(q->type, n_embd_head_qk_nope));

        // Compressed KV projection; each row is [kv_lora_rank latent | rope part].
        auto wkv_a_mqa_split = (const ggml_split_tensor_t *)model.layers[il].wkv_a_mqa->extra;
        GGML_ASSERT(wkv_a_mqa_split);
        ggml_tensor * kv_rope_compressed = ggml_mul_mat(ctx0, wkv_a_mqa_split->splits[id], cur);
        ggml_build_forward_expand(gf, kv_rope_compressed);
        ggml_tensor * k_rope = ggml_view_3d(ctx0, kv_rope_compressed,
                n_embd_head_qk_rope, 1, n_tokens,
                kv_rope_compressed->nb[1], kv_rope_compressed->nb[1],
                ggml_row_size(kv_rope_compressed->type, kv_lora_rank));
        ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_rope_compressed,
                kv_lora_rank, n_tokens, kv_rope_compressed->nb[1], 0);

        if (rope_cache) {
            q_rope = ggml_rope_fast(ctx0, q_rope, rope_cache);
            k_rope = ggml_rope_fast(ctx0, k_rope, rope_cache);
        } else {
            q_rope = ggml_rope_ext(ctx0, q_rope, inp_pos, nullptr, n_rot, rope_type,
                    n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor_scaled, beta_fast, beta_slow);
            k_rope = ggml_rope_ext(ctx0, k_rope, inp_pos, nullptr, n_rot, rope_type,
                    n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor_scaled, beta_fast, beta_slow);
        }
        {
            auto kv_a_norm_split = (const ggml_split_tensor_t *)model.layers[il].attn_kv_a_norm->extra;
            GGML_ASSERT(kv_a_norm_split);
            kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                    kv_a_norm_split->splits[id], NULL, LLM_NORM_RMS, cb, il_id);
        }

        // Store [k_rope | kv_compressed] for the new tokens into this device's cache at kv_head.
        ggml_tensor * cache_local = cache_repl->splits[id];
        const auto row_size_cache = ggml_row_size(cache_local->type, kv_lora_rank + n_embd_head_qk_rope);
        ggml_tensor * cache_write_view = ggml_view_2d(ctx0, cache_local,
                cache_local->ne[0], n_tokens, row_size_cache, row_size_cache * kv_head);
        ggml_tensor * kvr = ggml_concat(ctx0, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), kv_compressed, 0);
        if (cparams.k_cache_hadamard) {
            kvr = ggml_hadamard(ctx0, kvr, 64);
        }
        const int cc_idx = 2 * n_device * il + 2 * id;
        GGML_ASSERT(cc_idx + 1 < (int)lctx.cache_copies.size());
        lctx.cache_copies[cc_idx + 0].cpy  = ggml_cpy(ctx0, kvr, cache_write_view);
        lctx.cache_copies[cc_idx + 0].step = row_size_cache;
        ggml_build_forward_expand(gf, lctx.cache_copies[cc_idx + 0].cpy);

        ggml_tensor * kv_cache = ggml_view_2d(ctx0, cache_local,
                kv_lora_rank + n_embd_head_qk_rope, n_kv,
                row_size_cache, 0);
        cb(kv_cache, "kv_cache", il_id);

        // Decompressed-K/V path: only for pp_opt, long enough contexts, and when all weights exist.
        constexpr int k_pp_opt_min_kv = 1024;
        const bool tp_pp_opt = pp_opt
            && (int)n_kv >= k_pp_opt_min_kv
            && model.layers[il].wk_b
            && model.layers[il].wv_b
            && model.layers[il].wk_b_pp;

        ggml_tensor * kqv_2d;
        if (tp_pp_opt) {
            auto wv_b_pp_split_raw = (const ggml_split_tensor_t *)model.layers[il].wv_b->extra;
            GGML_ASSERT(wv_b_pp_split_raw);
            ggml_tensor * wv_b_local_pp = wv_b_pp_split_raw->splits[id];
            ggml_tensor * kv_cache_nope = ggml_view_2d(ctx0, cache_local,
                    kv_lora_rank, n_kv,
                    row_size_cache,
                    ggml_row_size(cache_local->type, n_embd_head_qk_rope));
            cb(kv_cache_nope, "kv_cache_nope_pp", il_id);
            ggml_tensor * kv_cache_rope_view = ggml_view_3d(ctx0, cache_local,
                    n_embd_head_qk_rope, n_kv, 1,
                    row_size_cache, cache_local->nb[2], 0);
            cb(kv_cache_rope_view, "kv_cache_rope_pp", il_id);
            if (cparams.k_cache_hadamard) {
                kv_cache_rope_view = ggml_hadamard(ctx0, kv_cache_rope_view, 64);
                if (!model.khad_pretransformed) {
                    kv_cache_nope = ggml_hadamard(ctx0, kv_cache_nope, 64);
                }
            }
            const auto kv_type = GGML_TYPE_F16;
            // Shape-only template for ggml_repeat: broadcast the shared rope part to every local head.
            ggml_tensor repeater;
            repeater.ne[0] = n_embd_head_qk_rope;
            repeater.ne[1] = n_kv;
            repeater.ne[2] = n_head_local;
            repeater.ne[3] = 1;
            ggml_tensor * k_rope_rep;
            if (kv_cache_rope_view->type == kv_type) {
                k_rope_rep = ggml_repeat(ctx0, kv_cache_rope_view, &repeater);
            } else {
                auto kv_rope_f16 = ggml_cast(ctx0, kv_cache_rope_view, kv_type);
                k_rope_rep = ggml_repeat(ctx0, kv_rope_f16, &repeater);
            }
            cb(k_rope_rep, "k_rope_rep_pp", il_id);

            // V = wv_b * latent, viewed as [n_embd_head_v, n_kv, n_head_local].
            auto wv_b_2d = ggml_reshape_2d(ctx0, wv_b_local_pp,
                    kv_lora_rank, n_head_local * n_embd_head_v);
            ggml_tensor * v_2d = ggml_mul_mat(ctx0, wv_b_2d, kv_cache_nope);
            cb(v_2d, "v_2d_pp", il_id);
            ggml_tensor * v_f32 = ggml_view_3d(ctx0, v_2d,
                    n_embd_head_v, n_kv, n_head_local,
                    v_2d->nb[1],
                    n_embd_head_v * v_2d->nb[0],
                    0);

            // K_nope = wk_b_pp * latent, viewed as [n_embd_head_qk_nope, n_kv, n_head_local].
            auto wk_b_pp_split = (const ggml_split_tensor_t *)model.layers[il].wk_b_pp->extra;
            GGML_ASSERT(wk_b_pp_split);
            ggml_tensor * wk_b_pp_local = wk_b_pp_split->splits[id];
            GGML_ASSERT(wk_b_pp_local);
            ggml_tensor * wk_b_T_2d = ggml_reshape_2d(ctx0, wk_b_pp_local,
                    kv_lora_rank, n_head_local * n_embd_head_qk_nope);
            ggml_tensor * k_nope_2d = ggml_mul_mat(ctx0, wk_b_T_2d, kv_cache_nope);
            cb(k_nope_2d, "k_nope_2d_pp", il_id);
            ggml_tensor * k_nope_f32 = ggml_view_3d(ctx0, k_nope_2d,
                    n_embd_head_qk_nope, n_kv, n_head_local,
                    k_nope_2d->nb[1],
                    n_embd_head_qk_nope * k_nope_2d->nb[0],
                    0);
            ggml_tensor * v      = ggml_cast(ctx0, v_f32, kv_type);
            ggml_tensor * k_nope = ggml_cast(ctx0, k_nope_f32, kv_type);
            ggml_build_forward_expand(gf, v);
            ggml_build_forward_expand(gf, k_nope);
            ggml_tensor * k = ggml_concat(ctx0, k_rope_rep, k_nope, 0);
            ggml_build_forward_expand(gf, k);
            cb(k, "k_full_pp", il_id);

            // Full query [rope | nope] per head, permuted to [head_dim, n_tokens, n_head_local].
            ggml_tensor * q_full = ggml_concat(ctx0, q_rope, q_nope, 0);
            q_full = ggml_permute(ctx0, q_full, 0, 2, 1, 3);
            ggml_build_forward_expand(gf, q_full);
            cb(q_full, "q_concat_pp", il_id);
            ggml_tensor * kqv = ggml_flash_attn_ext(ctx0, q_full, k, v, KQ_mask,
                    kq_scale, hparams.f_max_alibi_bias, 0.f);
            if (use_f32_attn_precision || q_full->ne[1] <= 8) {
                ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
            }
            cb(kqv, "kqv_pp", il_id);
            kqv_2d = ggml_reshape_2d(ctx0, kqv, n_embd_head_v * n_head_local, n_tokens);
        } else {
            // Absorbed path: project q_nope into the latent space with wk_b and attend over
            // the compressed cache directly; decompress the result with wv_b afterwards.
            auto wk_b_split = (const ggml_split_tensor_t *)model.layers[il].wk_b->extra;
            GGML_ASSERT(wk_b_split);
            ggml_tensor * wk_b_local = wk_b_split->splits[id];
            ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
            ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b_local, q_nope_perm);
            ggml_tensor * q_combined = ggml_concat(ctx0,
                    ggml_permute(ctx0, q_rope, 0, 2, 1, 3), q_nope2, 0);
            if (cparams.k_cache_hadamard) {
                q_combined = ggml_hadamard(ctx0, q_combined, 64);
            }
            ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, cache_local,
                    kv_lora_rank, n_kv,
                    row_size_cache,
                    ggml_row_size(cache_local->type, n_embd_head_qk_rope));
            cb(kv_cache_lora, "kv_cache_lora", il_id);
            ggml_tensor * kqv_compressed = ggml_flash_attn_ext(ctx0,
                    q_combined, kv_cache, kv_cache_lora, KQ_mask,
                    kq_scale, hparams.f_max_alibi_bias, 0.f);
            cb(kqv_compressed, "kqv_compressed", il_id);
            if (use_f32_attn_precision) {
                ggml_flash_attn_ext_set_prec(kqv_compressed, GGML_PREC_F32);
            }
            if (cparams.k_cache_hadamard && !model.khad_pretransformed) {
                kqv_compressed = ggml_hadamard(ctx0, kqv_compressed, 64);
            }
            kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
            auto wv_b_split = (const ggml_split_tensor_t *)model.layers[il].wv_b->extra;
            GGML_ASSERT(wv_b_split);
            ggml_tensor * wv_b_local = wv_b_split->splits[id];
            ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b_local, kqv_compressed);
            if (n_tokens > 1) {
                kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
            }
            kqv_2d = ggml_reshape_2d(ctx0, kqv, n_embd_head_v * n_head_local, n_tokens);
        }

        ggml_tensor * partial = llm_build_lora_mm(lctx, ctx0, wo_split->splits[id], kqv_2d);
        // Add the layer input to exactly one partial so the reduced sum contains it once.
        if (!input_added) {
            partial = ggml_add(ctx0, partial, input);
            input_added = true;
        }
        // For larger batches the partials are reduced in cparams.reduce_type (unless that is F32).
        if (partial->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
            partial = ggml_cast(ctx0, partial, lctx.cparams.reduce_type);
        }
        ggml_build_forward_expand(gf, partial);
        attn_partials[id] = partial;
    }

    ggml_tensor * combined = ggml_reduce(ctx0, attn_partials.data(), n_device, GGML_OP_ADD);
    ggml_build_forward_expand(gf, combined);
    cb(combined, "attn_combined", il);
    return combined;
}
// Builds the (non tensor-parallel) attention sub-graph of DeepSeek2-style layer `il` and
// returns the output of the `wo` projection. The residual is NOT added here; the caller does it.
//
// Path selection:
//   * mla_attn == 0                : K/V are decompressed with wkv_b and attention is built
//                                    by llm_build_kv (which also applies wo).
//   * mla_attn > 1 && fa && pp_opt : the cached latent is decompressed with wkv_b per group of
//                                    heads (group size limited by attn_max_batch) and fed to
//                                    flash attention.
//   * fa && mla_attn in {1, 3}     : flash attention directly on the compressed cache, then wv_b.
//   * otherwise (MLA)              : explicit KQ -> softmax -> KQV on the compressed cache,
//                                    optionally chunked over heads by attn_max_batch, then wv_b.
//
// Parameters:
//   rope_cache              if non-null, RoPE uses ggml_rope_fast with this cache,
//                           otherwise ggml_rope_ext with attn_factor_scaled
//   kq_scale                attention softmax scale
//   use_f32_attn_precision  force GGML_PREC_F32 on flash attention
//   is_lite                 q comes from a single wq projection (no q low-rank + norm)
//   pp_opt                  prompt-processing layout/path selection (see above)
ggml_tensor * llm_build_context::build_deepseek2_layer_attention(
        ggml_cgraph * gf, int il,
        ggml_tensor * inpL,
        ggml_tensor * KQ_mask, ggml_tensor * inp_pos,
        ggml_tensor * rope_cache,
        float kq_scale, float attn_factor_scaled,
        bool use_f32_attn_precision,
        bool is_lite,
        bool pp_opt) {
    // Head geometry of layer `il`; the q views below also use n_embd_head_k(il), so derive
    // the nope width from the same layer to keep all views consistent.
    const uint32_t n_embd_head_qk_rope = hparams.n_rot;
    const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k(il) - hparams.n_rot;
    const uint32_t kv_lora_rank        = hparams.n_lora_kv;
    const uint32_t q_lora_rank         = hparams.n_lora_q;

    ggml_tensor * cur;
    cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
    cb(cur, "attn_norm", il);

    {
        ggml_tensor * q = nullptr;
        ggml_tensor * kv_rope_compressed = nullptr;
        ggml_tensor * q_rope;
        ggml_tensor * q_nope;
        ggml_tensor * k_rope;
        ggml_tensor * kv_compressed;
        if (model.layers[il].wkq_a_mqa) {
            // Fused projection: each mqa row is [q part | kv_lora_rank latent | rope part].
            auto mqa = ggml_mul_mat(ctx0, model.layers[il].wkq_a_mqa, cur);
            cb(mqa, "mqa", il);
            size_t qnb1;
            if (!is_lite) {
                q = ggml_view_2d(ctx0, mqa, q_lora_rank, n_tokens, mqa->nb[1], 0);
                q = llm_build_norm(ctx0, q, hparams, model.layers[il].attn_q_a_norm, NULL, LLM_NORM_RMS, cb, il);
                q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
                qnb1 = q->nb[1];
                cb(q, "q", il);
                kv_rope_compressed = ggml_view_2d(ctx0, mqa, kv_lora_rank + n_embd_head_qk_rope, n_tokens, mqa->nb[1],
                        q_lora_rank*ggml_element_size(mqa));
            } else {
                q = ggml_view_2d(ctx0, mqa, n_embd_k_gqa, n_tokens, mqa->nb[1], 0);
                kv_rope_compressed = ggml_view_2d(ctx0, mqa, kv_lora_rank + n_embd_head_qk_rope, n_tokens, mqa->nb[1],
                        n_embd_k_gqa*ggml_element_size(mqa));
                qnb1 = mqa->nb[1];
            }
            q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
                    ggml_row_size(q->type, hparams.n_embd_head_k(il)), qnb1, 0);
            q_rope = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
                    ggml_row_size(q->type, hparams.n_embd_head_k(il)), qnb1, ggml_row_size(q->type, n_embd_head_qk_nope));
            k_rope = ggml_view_3d(ctx0, kv_rope_compressed, n_embd_head_qk_rope, 1, n_tokens,
                    mqa->nb[1], mqa->nb[1], ggml_row_size(kv_rope_compressed->type, kv_lora_rank));
            kv_compressed = ggml_view_2d(ctx0, kv_rope_compressed, kv_lora_rank, n_tokens, mqa->nb[1], 0);
        }
        else {
            if (!is_lite) {
                q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
                cb(q, "q", il);
                kv_rope_compressed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
                cb(kv_rope_compressed, "kv_rope_compressed", il);
                ggml_build_forward_expand(gf, q);
                ggml_build_forward_expand(gf, kv_rope_compressed);
                q = llm_build_norm(ctx0, q, hparams, model.layers[il].attn_q_a_norm, NULL, LLM_NORM_RMS, cb, il);
                cb(q, "q", il);
                q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
                cb(q, "q", il);
            } else {
                q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
                cb(q, "q", il);
                kv_rope_compressed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
                cb(kv_rope_compressed, "kv_rope_compressed", il);
                ggml_build_forward_expand(gf, q);
                ggml_build_forward_expand(gf, kv_rope_compressed);
            }
            q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
                    ggml_row_size(q->type, hparams.n_embd_head_k(il)),
                    ggml_row_size(q->type, hparams.n_embd_head_k(il) * n_head), 0);
            q_rope = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
                    ggml_row_size(q->type, hparams.n_embd_head_k(il)),
                    ggml_row_size(q->type, hparams.n_embd_head_k(il) * n_head),
                    ggml_row_size(q->type, n_embd_head_qk_nope));
            k_rope = ggml_view_3d(ctx0, kv_rope_compressed, n_embd_head_qk_rope, 1, n_tokens,
                    kv_rope_compressed->nb[1],
                    kv_rope_compressed->nb[1],
                    ggml_row_size(kv_rope_compressed->type, kv_lora_rank));
            kv_compressed = ggml_view_2d(ctx0, kv_rope_compressed, kv_lora_rank, n_tokens,
                    kv_rope_compressed->nb[1], 0);
        }
        cb(q_nope, "q_nope", il);
        cb(q_rope, "q_rope", il);
        cb(k_rope, "k_rope", il);
        cb(kv_compressed, "kv_compressed", il);
        ggml_build_forward_expand(gf, q_rope);
        ggml_build_forward_expand(gf, k_rope);
        if (rope_cache) {
            q_rope = ggml_rope_fast(ctx0, q_rope, rope_cache);
            k_rope = ggml_rope_fast(ctx0, k_rope, rope_cache);
        } else {
            q_rope = ggml_rope_ext(ctx0, q_rope, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor_scaled, beta_fast, beta_slow);
            k_rope = ggml_rope_ext(ctx0, k_rope, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor_scaled, beta_fast, beta_slow);
        }
        cb(q_rope, "q_rope", il);
        cb(k_rope, "k_rope", il);
        ggml_build_forward_expand(gf, q_rope);
        ggml_build_forward_expand(gf, k_rope);
        kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, model.layers[il].attn_kv_a_norm, NULL, LLM_NORM_RMS, cb, il);
        cb(kv_compressed, "kv_compressed", il);

        if (lctx.cparams.mla_attn) {
            // Transposed latent cache (v_l) is only maintained for mla_attn == 1 without flash attention.
            ggml_tensor * kv_cache_trans = nullptr;
            if (lctx.cparams.mla_attn == 1 && !lctx.cparams.flash_attn) {
                ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, kv_lora_rank,
                        ggml_row_size(kv_self.v_l[il]->type, kv_self.size), ggml_row_size(kv_self.v_l[il]->type, kv_head));
                cb(kv_cache_trans_view, "kv_cache_trans_view", il);
                ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
                kv_cache_trans = ggml_view_2d(ctx0, kv_self.v_l[il],
                        n_kv, kv_lora_rank,
                        ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
                        0);
                cb(kv_cache_trans, "kv_cache_trans", il);
            }

            // Store [k_rope | kv_compressed] for the new tokens into the cache at kv_head.
            ggml_tensor * kvr = ggml_concat(ctx0, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), kv_compressed, 0);
            cb(kvr, "kvr", il);
            auto row_size = ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
            ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.k_l[il], kv_self.k_l[il]->ne[0], n_tokens,
                    row_size, row_size*kv_head);
            lctx.cache_copies[2*il+0].cpy = ggml_cpy(ctx0, kvr, kv_cache_view);
            lctx.cache_copies[2*il+0].step = row_size;
            ggml_build_forward_expand(gf, lctx.cache_copies[2*il+0].cpy);
            ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.k_l[il],
                    kv_lora_rank + n_embd_head_qk_rope, n_kv,
                    ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
            cb(kv_cache, "kv_cache", il);

            ggml_tensor * kqv;
            if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && pp_opt) {
                // Decompressed-K/V flash attention, processed in groups of n_max_head heads.
                auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.k_l[il], kv_lora_rank, n_kv, kv_self.k_l[il]->nb[1],
                        ggml_row_size(kv_self.k_l[il]->type, n_embd_head_qk_rope));
                // Size (MiB) of the fully decompressed K/V in F32.
                auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
                int n_max_head = n_head;
                if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) {
                    n_max_head = 1;
                    for (int niter = 2; niter < n_head; ++niter) {
                        if (n_head % niter == 0 && kv_f32_size/niter <= cparams.attn_max_batch) {
                            n_max_head = n_head/niter;
                            break;
                        }
                    }
                }
                GGML_ASSERT(n_head % n_max_head == 0);
                auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head;
                auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.k_l[il], n_embd_head_qk_rope, n_kv, 1,
                        kv_self.k_l[il]->nb[1], kv_self.k_l[il]->nb[2], 0);
                auto kv_type = lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu ? kv_self.k_l[il]->type : GGML_TYPE_F16;
                // Shape-only template for ggml_repeat: broadcast the shared rope part to each head of a group.
                ggml_tensor repeater;
                repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1;
                ggml_tensor * k_rope;
                if (kv_cache_rope->type == kv_type) {
                    k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
                } else {
                    // kv_type is F16 whenever this branch is reachable.
                    auto kv_cache_rope_f16 = ggml_cast(ctx0, kv_cache_rope, kv_type);
                    k_rope = ggml_repeat(ctx0, kv_cache_rope_f16, &repeater);
                }
                cb(k_rope, "k_rope", il);
                auto q = ggml_concat(ctx0, q_rope, q_nope, 0);
                q = ggml_permute(ctx0, q, 0, 2, 1, 3);
                cb(q, "q_concat", il);
                ggml_build_forward_expand(gf, q);
                for (int iter = 0; iter < n_head/n_max_head; ++iter) {
                    auto wkv_b = ggml_view_2d(ctx0, model.layers[il].wkv_b, model.layers[il].wkv_b->ne[0], n_per_head*n_max_head,
                            model.layers[il].wkv_b->nb[1], model.layers[il].wkv_b->nb[1]*n_per_head*n_max_head*iter);
                    auto kv_f32 = ggml_mul_mat(ctx0, wkv_b, kv_cache_nope);
                    cb(kv_f32, "kv_f32", il);
                    auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v_full, n_kv, n_max_head,
                            ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v_full)),
                            ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v_full),
                            ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
                    cb(v_f32, "v_f32", il);
                    auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head,
                            ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v_full)),
                            ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v_full), 0);
                    cb(k_nope_f32, "k_nope_f32", il);
                    auto v = ggml_cast(ctx0, v_f32, kv_type);
                    cb(v, "v", il);
                    auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_type);
                    cb(k_nope, "k_nope", il);
                    ggml_build_forward_expand(gf, k_nope);
                    ggml_build_forward_expand(gf, v);
                    auto k = ggml_concat(ctx0, k_rope, k_nope, 0);
                    cb(k, "k", il);
                    ggml_build_forward_expand(gf, k);
                    auto q_iter = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], n_max_head,
                            q->nb[1], q->nb[2], q->nb[2]*n_max_head*iter);
                    kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
                    if (use_f32_attn_precision || q->ne[1] <= 8) {
                        ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
                    }
                    cb(kqv, "kqv", il);
                    if (iter == 0) {
                        cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens);
                    } else {
                        cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0);
                    }
                    ggml_build_forward_expand(gf, cur);
                }
            }
            else {
                // Absorbed path: project q_nope into the latent space with wk_b.
                ggml_tensor * kqv_compressed = nullptr;
                auto wk_b = model.layers[il].wk_b->ne[1] == kv_lora_rank ? model.layers[il].wk_b
                        : ggml_reshape_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head);
                q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
                cb(q_nope, "q_nope_perm", il);
                struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
                cb(q_nope2, "q_nope2", il);
                ggml_tensor * q = ggml_concat(ctx0, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), q_nope2, 0);
                cb(q, "q", il);
                if (lctx.cparams.flash_attn && (lctx.cparams.mla_attn == 1 || lctx.cparams.mla_attn == 3)) {
                    ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.k_l[il],
                            kv_lora_rank, n_kv,
                            ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope),
                            ggml_row_size(kv_self.k_l[il]->type, n_embd_head_qk_rope));
                    cb(kv_cache_lora, "kv_cache_lora", il);
                    kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
                    cb(kqv_compressed, "kqv_compressed", il);
                    if (use_f32_attn_precision) {
                        ggml_flash_attn_ext_set_prec(kqv_compressed, GGML_PREC_F32);
                    }
                    kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
                    cb(kqv_compressed, "kqv_compressed_perm", il);
                }
                else {
                    if (lctx.cparams.mla_attn > 1) {
                        // Build the transposed latent on the fly from the K cache.
                        ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.k_l[il],
                                kv_lora_rank, n_kv,
                                ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope),
                                ggml_row_size(kv_self.k_l[il]->type, n_embd_head_qk_rope));
                        // Name the lora view itself (not kv_cache, which is already named "kv_cache").
                        cb(kv_cache_lora, "kv_cache_lora", il);
                        kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
                        cb(kv_cache_trans, "kv_cache_trans", il);
                    }
                    // Size (MiB) of the F32 KQ matrix; chunk over heads if it exceeds attn_max_batch.
                    auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
                    if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) {
                        if (!pp_opt) {
                            q = ggml_permute(ctx0, q, 0, 2, 1, 3);
                            cb(q, "q_perm", il);
                        }
                        ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
                        if (kv_cache->ne[1] < 256) {
                            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
                        }
                        cb(kq, "kq", il);
                        if (!pp_opt) {
                            kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
                            cb(kq, "kq_perm", il);
                        }
                        kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
                        cb(kq, "kq_soft_max_ext", il);
                        if (!pp_opt) {
                            kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
                            cb(kq, "kq_soft_max_ext_perm", il);
                        }
                        kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
                        cb(kqv_compressed, "kqv_compressed", il);
                        if (!pp_opt) {
                            kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
                            cb(kqv_compressed, "kqv_compressed_perm", il);
                        }
                    } else {
                        int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
                        n_step = std::min(n_step, int(q->ne[2]));
                        int n_per_step = (q->ne[2] + n_step - 1)/n_step;
                        for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) {
                            int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head;
                            ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head);
                            ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
                            kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
                            ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
                            if (i_head == 0) {
                                kqv_compressed = kqv_i;
                            } else {
                                kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
                            }
                            ggml_build_forward_expand(gf, kqv_compressed);
                        }
                        cb(kqv_compressed, "kqv_compressed", il);
                    }
                }
                // Decompress the attended latent with wv_b.
                auto wv_b = model.layers[il].wv_b;
                if (wv_b->ne[1] != n_embd_head_v) {
                    wv_b = ggml_reshape_3d(ctx0, wv_b, kv_lora_rank, n_embd_head_v, n_head);
                    cb(wv_b, "wv_b", il);
                }
                kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
                cb(kqv, "kqv", il);
                if (n_tokens > 1) {
                    kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
                    cb(kqv, "kqv_perm", il);
                }
                cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens);
                cb(cur, "kqv_2d", il);
            }
            ggml_build_forward_expand(gf, cur);
            cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
            cb(cur, "kqv_out", il);
        }
        else {
            // Non-MLA: decompress K/V with wkv_b and use the regular KV-cache attention helper.
            struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
            cb(kv, "kv", il);
            struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
                    ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v_full),
                    ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v_full)),
                    0);
            cb(k_nope, "k_nope", il);
            struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v_full, n_head, n_tokens,
                    ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v_full)),
                    ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v_full)*n_head),
                    ggml_row_size(kv->type, (n_embd_head_qk_nope)));
            cb(v_states, "v_states", il);
            v_states = ggml_cont(ctx0, v_states);
            cb(v_states, "v_states", il);
            v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v_full * n_head, n_tokens,
                    ggml_row_size(kv->type, hparams.n_embd_head_v_full * n_head),
                    0);
            cb(v_states, "v_states", il);
            struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_rope, 0);
            cb(q_states, "q_states", il);
            struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_rope, q_rope), 0);
            cb(k_states, "k_states", il);
            cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                    model.layers[il].wo, NULL,
                    k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
        }
    }
    return cur;
}
// Builds the forward graph for DeepSeek2-style MLA models.
//
// * If cparams.mtp_op_type != MTP_OP_NONE, only the MTP (NextN) head is built
//   (GLM_DSA only; aborts otherwise) and the graph is returned early.
// * Otherwise the first n_layer - nextn_predict_layers layers are built. Attention goes
//   through build_deepseek2_tp_attention when split mode is GRAPH/ATTN and the layer's wo
//   carries split data, else through build_deepseek2_layer_attention. The first
//   n_layer_dense_lead layers use a dense FFN; later layers use MoE + shared experts.
// * For TP layers the residual adds are skipped here (the TP attention builder adds the
//   layer input itself; presumably the TP FFN builders do likewise — verify in their code).
ggml_cgraph * llm_build_context::build_deepseek2() {
    const bool tp_mode = (model.split_mode == LLAMA_SPLIT_MODE_GRAPH ||
                          model.split_mode == LLAMA_SPLIT_MODE_ATTN);
#ifdef GGML_USE_VULKAN
    const bool use_f32_attn_precision = true;
#else
    const bool use_f32_attn_precision = lctx.cparams.graph_attn_precision == GGML_TYPE_F32;
#endif
    ggml_cgraph * gf = new_graph_custom();
    // Local copy: reduced to n_outputs before the last layer's FFN. The attention builders
    // read the member n_tokens, not this local.
    int32_t n_tokens = this->n_tokens;
    bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);

    // YaRN: kq_scale absorbs mscale^2. attn_factor_scaled = 1 / (1 + 0.1*ln(1/freq_scale)) is
    // used as RoPE's attn_factor (presumably to cancel the magnitude factor applied inside
    // YaRN RoPE; see llama.cpp discussion #7416).
    const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
    const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k(0)));
    const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));

    struct ggml_tensor * cur;
    struct ggml_tensor * inpL;
    inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
    struct ggml_tensor * inp_pos = build_inp_pos();
    struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
    bool pp_opt = n_tokens >= 128 && lctx.cparams.mla_attn > 1;

    // The cached RoPE must use the same attn_factor as the ggml_rope_ext fallback in the
    // attention builders (attn_factor_scaled); otherwise the two RoPE paths scale q/k
    // differently whenever YaRN extrapolation is active.
    auto rope_cache = cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ?
        ggml_rope_cache(ctx0, inp_pos, nullptr, n_rot, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor_scaled, beta_fast, beta_slow) : nullptr;

    if (cparams.mtp_op_type != MTP_OP_NONE) {
        if (model.arch != LLM_ARCH_GLM_DSA || !model.mtp || hparams.nextn_predict_layers == 0) {
            GGML_ABORT("MTP tail is only wired for GLM_DSA models with NextN layers enabled");
        }
        // Hidden states from the main model: one row per token for warmup / accepted-token
        // updates, a single vector otherwise.
        ggml_tensor * hidden_states_from_main_model;
        if (cparams.mtp_op_type == MTP_OP_WARMUP || cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
            hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
        } else {
            hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
        }
        ggml_set_name(hidden_states_from_main_model, "inp_mtp_states");
        ggml_set_input(hidden_states_from_main_model);
        lctx.inp_mtp_states = hidden_states_from_main_model;
        const int il_mtp = hparams.n_layer - 1;
        const auto & mtp_layer = model.layers[il_mtp];
        cur = build_deepseek2_mtp(mtp_layer, hidden_states_from_main_model, gf, inp_pos, rope_cache);
        ggml_build_forward_expand(gf, cur);
        return gf;
    }

    int n_active_layers = hparams.n_layer - hparams.nextn_predict_layers;
    for (int il = 0; il < n_active_layers; ++il) {
        struct ggml_tensor * inpSA = inpL;
        bool is_tp_layer = tp_mode && model.layers[il].wo && model.layers[il].wo->extra;
        if (is_tp_layer) {
            cur = build_deepseek2_tp_attention(gf, il, inpL, KQ_mask, inp_pos, rope_cache,
                    kq_scale, attn_factor_scaled,
                    use_f32_attn_precision, is_lite, pp_opt);
        } else {
            cur = build_deepseek2_layer_attention(gf, il, inpL, KQ_mask, inp_pos, rope_cache,
                    kq_scale, attn_factor_scaled,
                    use_f32_attn_precision, is_lite, pp_opt);
        }
        // After the last layer's attention only the output rows are kept (unless MTP needs all).
        if (il == n_active_layers - 1 && !lctx.cparams.mtp) {
            struct ggml_tensor * inp_out_ids = build_inp_out_ids();
            n_tokens = n_outputs;
            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            cb(cur, "last_attn", il);
            cb(inpSA, "last_ffn_inp", il);
        }
        struct ggml_tensor * ffn_inp;
        if (is_tp_layer) {
            ffn_inp = cur;
        } else {
            ffn_inp = ggml_add(ctx0, cur, inpSA);
        }
        cb(ffn_inp, "ffn_inp", il);
        if (is_tp_layer) {
            // TP FFN builders receive ffn_norm and apply it themselves.
            cur = ffn_inp;
        } else {
            cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il);
            cb(cur, "ffn_norm", il);
        }
        if ((uint32_t) il < hparams.n_layer_dense_lead) {
            cur = llm_build_ffn(ctx0, lctx,
                    is_tp_layer ? model.layers[il].ffn_norm : nullptr, cur,
                    model.layers[il].ffn_up, NULL, NULL,
                    model.layers[il].ffn_gate, NULL, NULL,
                    model.layers[il].ffn_down, NULL, NULL,
                    NULL,
                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il,
                    gf,
                    is_tp_layer);
            cb(cur, "ffn_out", il);
        } else if (is_tp_layer) {
            cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
                    model.layers[il].ffn_gate_inp, nullptr,
                    model.layers[il].ffn_up_exps, nullptr,
                    model.layers[il].ffn_gate_exps, nullptr,
                    model.layers[il].ffn_down_exps, nullptr,
                    model.layers[il].ffn_exp_probs_b,
                    model.layers[il].ffn_up_shexp, nullptr,
                    model.layers[il].ffn_gate_shexp, nullptr,
                    model.layers[il].ffn_down_shexp, nullptr,
                    n_expert, n_expert_used,
                    LLM_FFN_SILU, hparams.expert_weights_norm,
                    true, hparams.expert_weights_scale,
                    (enum llm_expert_gating_func_type) hparams.expert_gating_func,
                    LLM_FFN_SILU, cb, il, gf, true, model.layers[il].ffn_up_gate_exps);
            cb(cur, "ffn_out", il);
        } else {
            ggml_tensor * moe_out =
                llm_build_moe_ffn(ctx0, lctx, cur,
                        model.layers[il].ffn_gate_inp,
                        model.layers[il].ffn_up_exps,
                        model.layers[il].ffn_gate_exps,
                        model.layers[il].ffn_down_exps,
                        model.layers[il].ffn_exp_probs_b,
                        n_expert, n_expert_used,
                        LLM_FFN_SILU, hparams.expert_weights_norm,
                        true, hparams.expert_weights_scale,
                        (enum llm_expert_gating_func_type) hparams.expert_gating_func,
                        cb, il, gf, false, model.layers[il].ffn_up_gate_exps);
            cb(moe_out, "ffn_moe_out", il);
            ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, nullptr, cur,
                    model.layers[il].ffn_up_shexp, NULL, NULL,
                    model.layers[il].ffn_gate_shexp, NULL, NULL,
                    model.layers[il].ffn_down_shexp, NULL, NULL,
                    NULL,
                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf);
            cb(ffn_shexp, "ffn_shexp", il);
            cur = ggml_add(ctx0, moe_out, ffn_shexp);
            cb(cur, "ffn_out", il);
        }
        if (!is_tp_layer) {
            cur = ggml_add(ctx0, cur, ffn_inp);
        }
        cur = lctx.cvec.apply_to(ctx0, cur, il);
        cb(cur, "l_out", il);
        inpL = cur;
    }
    cur = build_output(lctx, ctx0, inpL, model.output, model.output_norm, cb);
    cb(cur, "result_output", -1);
    ggml_build_forward_expand(gf, cur);
    return gf;
}
struct ggml_tensor * llm_build_context::build_deepseek2_mtp(
const llama_layer & mtp_layer,
struct ggml_tensor * prev_embeddings,
struct ggml_cgraph * gf,
struct ggml_tensor * inp_pos,
[[maybe_unused]] struct ggml_tensor * rope_cache) {
#ifdef GGML_USE_VULKAN
constexpr bool use_f32_attn_precision = true;
#else
constexpr bool use_f32_attn_precision = false;
#endif
const int il = hparams.n_layer - 1;
const uint32_t n_embd_head_k_mtp = hparams.n_embd_head_k(il);
const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k_mtp));
const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
ggml_tensor * mtp_embd_weights = mtp_layer.nextn.embed_tokens;
if (mtp_embd_weights == nullptr) {
mtp_embd_weights = model.tok_embd;
}
ggml_tensor * token_emb = build_inp_embd_mtp(mtp_embd_weights);
ggml_tensor * token_emb_norm = llm_build_norm(ctx0, token_emb, hparams, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, cb, il);
ggml_tensor * hidden_state_norm = llm_build_norm(ctx0, prev_embeddings, hparams, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, cb, il);
if (mtp_layer.nextn.eh_proj == nullptr) {
GGML_ABORT("GLM_DSA MTP requires nextn.eh_proj");
}
ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
cb(combined, "mtp_concat", il);
ggml_tensor * cur = llm_build_lora_mm(lctx, ctx0, mtp_layer.nextn.eh_proj, combined);
struct ggml_tensor * inpSA = cur;
cur = build_deepseek2_layer_attention(gf, il, cur, KQ_mask, inp_pos, nullptr,
kq_scale, attn_factor_scaled,
use_f32_attn_precision, false, false);
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "mtp_ffn_inp", il);
if (inp_out_ids) {
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
}
cur = llm_build_norm(ctx0, ffn_inp, hparams, mtp_layer.ffn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
{
ggml_tensor * moe_out =
llm_build_moe_ffn(ctx0, lctx, cur,
mtp_layer.ffn_gate_inp,
mtp_layer.ffn_up_exps,
mtp_layer.ffn_gate_exps,
mtp_layer.ffn_down_exps,
mtp_layer.ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
(enum llm_expert_gating_func_type) hparams.expert_gating_func,
cb, il, gf, false, mtp_layer.ffn_up_gate_exps);
cb(moe_out, "ffn_moe_out", il);
ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, nullptr, cur,
mtp_layer.ffn_up_shexp, NULL, NULL,
mtp_layer.ffn_gate_shexp, NULL, NULL,
mtp_layer.ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(ffn_shexp, "ffn_shexp", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "mtp_ffn_out_resid", il);
if (mtp_layer.nextn.shared_head_norm == nullptr) {
GGML_ABORT("GLM_DSA MTP requires nextn.shared_head_norm");
}
cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "result_norm", -1);
ggml_tensor * mtp_head_weights = mtp_layer.nextn.shared_head_head;
if (mtp_head_weights == nullptr) {
mtp_head_weights = model.output;
}
cur = llm_build_lora_mm(lctx, ctx0, mtp_head_weights, cur);
cb(cur, "result_output", -1);
return cur;
}