#include "models.h"
void llama_model_step35::load_arch_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_rot_full = hparams.n_rot_full / 2;
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
}
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer);
ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false);
ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false);
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
switch (hparams.n_layer - hparams.nextn_predict_layers) {
case 45: type = LLM_TYPE_196B_A11B; break;
default: type = LLM_TYPE_UNKNOWN;
}
}
void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
LLAMA_LOAD_LOCALS;
const uint32_t n_main = n_layer - hparams.nextn_predict_layers;
const bool mtp_only = (hparams.nextn_predict_layers > 0) &&
(ml.get_weight("blk.0.attn_norm.weight") == nullptr);
const std::string mtp_probe = "blk." + std::to_string(n_main) + ".nextn.eh_proj.weight";
const bool trunk_only = (hparams.nextn_predict_layers > 0) &&
(ml.get_weight(mtp_probe.c_str()) == nullptr);
const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0;
const int mtp_flags = trunk_only ? TENSOR_NOT_REQUIRED : 0;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, trunk_flags);
uint32_t n_rot_max = 0;
for (int i = 0; i < n_layer; ++i) {
n_rot_max = std::max(n_rot_max, hparams.n_rot(i));
}
if (n_rot_max == 0) {
n_rot_max = n_rot;
}
auto load_block_trunk = [&](int i, int flags) {
auto & layer = layers[i];
const uint32_t n_head_l = hparams.n_head(i);
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i);
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i);
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
} else {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
}
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, flags);
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
const int64_t n_ff_exp = hparams.n_ff_exp;
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
};
auto load_block_mtp = [&](int i, bool is_first_mtp) {
auto & layer = layers[i];
const uint32_t n_head_l = hparams.n_head(i);
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i);
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i);
const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED);
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED);
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED);
} else {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED);
}
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags);
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
const int64_t n_ff_exp = hparams.n_ff_exp;
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags);
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
};
for (int i = 0; i < (int) n_main; ++i) {
load_block_trunk(i, trunk_flags);
}
for (int i = (int) n_main; i < n_layer; ++i) {
load_block_mtp(i, i == (int) n_main);
}
}
std::unique_ptr<llm_graph_context> llama_model_step35::build_arch_graph(const llm_graph_params & params) const {
if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
return std::make_unique<graph_mtp>(*this, params);
}
return std::make_unique<graph>(*this, params);
}
llama_model_step35::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_out_ids = build_inp_out_ids();
const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
const uint32_t n_head_l = hparams.n_head(il);
const uint32_t n_head_kv_l = hparams.n_head_kv(il);
const float freq_base_l = model.get_rope_freq_base(cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
cur = inpL;
cb(cur, "attn_norm_in", il);
{
cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens);
if (model.layers[il].attn_q_norm) {
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
}
if (model.layers[il].attn_k_norm) {
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);
}
const bool is_swa = hparams.is_swa(il);
ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il);
const int64_t n_rot_l = hparams.n_rot(il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur_pos", il);
cb(Kcur, "Kcur_pos", il);
const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k));
ggml_tensor * attn_out = build_attn(inp_attn,
nullptr, nullptr, nullptr,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(attn_out, "attn_out", il);
if (model.layers[il].wqkv_gate) {
ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, cur); cb(gate, "attn_gate", il);
gate = ggml_sigmoid(ctx0, gate);
cb(gate, "attn_gate_sigmoid", il);
ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens);
ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
cb(gate_3d, "attn_gate_3d", il);
attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
cb(attn_3d, "attn_gated_3d", il);
attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
cb(attn_out, "attn_gated", il);
}
cur = build_lora_mm(model.layers[il].wo, attn_out, model.layers[il].wo_s);
cb(cur, "attn_proj", il);
}
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
cur = build_norm(ffn_inp, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, nullptr,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
ggml_tensor * moe_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(moe_out, "ffn_moe_out", il);
ggml_tensor * sh_out = build_ffn(cur,
model.layers[il].ffn_up_shexp, nullptr, nullptr,
model.layers[il].ffn_gate_shexp, nullptr, nullptr,
model.layers[il].ffn_down_shexp, nullptr, nullptr,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(sh_out, "ffn_shared_out", il);
cur = ggml_add(ctx0, moe_out, sh_out);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
inpL = cur;
}
cur = inpL;
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params)
: llm_graph_context(params) {
GGML_ASSERT(hparams.nextn_predict_layers > 0 && "STEP35 MTP requires nextn_predict_layers > 0");
const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
const auto & layer = model.layers[il];
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm");
GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm");
const uint32_t n_head_l = hparams.n_head(il);
const uint32_t n_head_kv_l = hparams.n_head_kv(il);
const float freq_base_l = model.get_rope_freq_base(cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_input(inp->tokens);
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
ggml_set_input(inp->embd);
ggml_set_name(inp->embd, "mtp_h_input");
ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
ggml_tensor * h_input = inp->embd;
ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
cb(tok_embd, "mtp_tok_embd", il);
res->add_input(std::move(inp));
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
cb(h_norm, "mtp_hnorm", il);
ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
cb(e_norm, "mtp_enorm", il);
ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, 0);
cb(concat, "mtp_concat", il);
ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
cb(cur, "mtp_eh_proj", il);
ggml_tensor * inpSA = cur;
cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "mtp_attn_norm", il);
ggml_tensor * Qcur = build_lora_mm(layer.wq, cur, layer.wq_s);
ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
cb(Qcur, "mtp_Qcur", il);
cb(Kcur, "mtp_Kcur", il);
cb(Vcur, "mtp_Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens);
if (layer.attn_q_norm) {
Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "mtp_Qcur_normed", il);
}
if (layer.attn_k_norm) {
Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il);
cb(Kcur, "mtp_Kcur_normed", il);
}
const bool is_swa = hparams.is_swa(il);
ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il);
const int64_t n_rot_l = hparams.n_rot(il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "mtp_Qcur_pos", il);
cb(Kcur, "mtp_Kcur_pos", il);
const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k));
ggml_tensor * attn_out = build_attn(inp_attn,
nullptr, nullptr, nullptr,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(attn_out, "mtp_attn_out", il);
if (layer.wqkv_gate) {
ggml_tensor * gate = build_lora_mm(layer.wqkv_gate, cur); cb(gate, "mtp_attn_gate", il);
gate = ggml_sigmoid(ctx0, gate);
cb(gate, "mtp_attn_gate_sigmoid", il);
ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens);
ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
cb(gate_3d, "mtp_attn_gate_3d", il);
attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
cb(attn_3d, "mtp_attn_gated_3d", il);
attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
cb(attn_out, "mtp_attn_gated", il);
}
cur = build_lora_mm(layer.wo, attn_out, layer.wo_s);
cb(cur, "mtp_attn_proj", il);
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "mtp_attn_residual", il);
ggml_tensor * ffn_inp = cur;
cur = build_norm(cur, layer.ffn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "mtp_ffn_norm", il);
if (layer.ffn_gate_inp == nullptr) {
cur = build_ffn(cur,
layer.ffn_up, layer.ffn_up_b, nullptr,
layer.ffn_gate, layer.ffn_gate_b, nullptr,
layer.ffn_down, layer.ffn_down_b, nullptr,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "mtp_ffn_out", il);
} else {
ggml_tensor * moe_out = build_moe_ffn(cur,
layer.ffn_gate_inp,
layer.ffn_up_exps,
layer.ffn_gate_exps,
layer.ffn_down_exps,
layer.ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(moe_out, "mtp_ffn_moe_out", il);
ggml_tensor * sh_out = build_ffn(cur,
layer.ffn_up_shexp, nullptr, nullptr,
layer.ffn_gate_shexp, nullptr, nullptr,
layer.ffn_down_shexp, nullptr, nullptr,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(sh_out, "mtp_ffn_shared_out", il);
cur = ggml_add(ctx0, moe_out, sh_out);
cb(cur, "mtp_ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "mtp_post_ffn", il);
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
? layer.nextn.shared_head_norm
: model.output_norm;
GGML_ASSERT(head_norm_w && "STEP35 MTP: missing both nextn.shared_head_norm and output_norm");
cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
cb(cur, "mtp_shared_head_norm", -1);
ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
GGML_ASSERT(head_w && "STEP35 MTP: missing LM head (nextn.shared_head_head or model.output)");
cur = build_lora_mm(head_w, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}