diff --git a/common/arg.cpp b/common/arg.cpp
index 8f54ee38c1b..3a95fd71ac2 100644
@@ -3562,12 +3562,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
- {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
+ {"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
common_speculative_type_to_str(params.speculative.type).c_str()),
[](common_params & params, const std::string & value) {
if (value == "none") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
+ } else if (value == "mtp") {
+ params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
} else if (value == "ngram-cache") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
} else if (value == "ngram-simple") {
diff --git a/common/common.cpp b/common/common.cpp
index 793b8fee7b8..a821da2da5d 100644
@@ -1420,6 +1420,11 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
goto done;
}
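+ // recurrent-state snapshots allow partial removal, but only within the last n_rs_seq positions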
+ if (llama_n_rs_seq(ctx) > 0) {
+ res = COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED;
+ goto done;
+ }
+
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
@@ -1490,6 +1495,12 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.n_ctx = params.n_ctx;
cparams.n_seq_max = params.n_parallel;
+ {
+ // enable partial rollback only for MTP: each recurrent snapshot slot costs memory,
+ // and MTP needs only a few slots (draft.n_max, typically 3-4), unlike other techniques
+ const bool has_mtp_spec = params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP;
+ cparams.n_rs_seq = has_mtp_spec ? (uint32_t) params.speculative.draft.n_max : 0u;
+ }
cparams.n_batch = params.n_batch;
cparams.n_ubatch = params.n_ubatch;
cparams.n_threads = params.cpuparams.n_threads;
diff --git a/common/common.h b/common/common.h
index a564b3b8c2b..0c28c8b3497 100644
@@ -159,6 +159,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
+ COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
@@ -347,11 +348,17 @@ struct common_params_speculative_ngram_cache {
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding
};
+struct common_params_speculative_mtp {
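+ // weights and context params for the MTP draft head; typically the same GGUF as the target
+ // model, loaded with llama_model_params.override_arch set to the corresponding *_mtp arch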
+ llama_model * model = nullptr;
+ llama_context_params cparams;
+};
+
struct common_params_speculative {
// TODO: become a vector in order to support "chains of speculators"
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
common_params_speculative_draft draft;
+ common_params_speculative_mtp mtp;
common_params_speculative_ngram_mod ngram_mod;
common_params_speculative_ngram_map ngram_simple;
@@ -879,9 +886,10 @@ std::string common_get_model_endpoint();
//
enum common_context_seq_rm_type {
- COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module)
- COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
- COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
+ COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module)
+ COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
+ COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
+ COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED = 3, // can seq_rm partial sequences, bounded by n_rs_seq
};
// check if the llama_context can remove sequences
diff --git a/common/speculative.cpp b/common/speculative.cpp
index bbf88fa6e71..b234c0a9617 100644
@@ -22,6 +22,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
COMMON_SPECULATIVE_TYPE_NONE,
COMMON_SPECULATIVE_TYPE_DRAFT,
COMMON_SPECULATIVE_TYPE_EAGLE3,
+ COMMON_SPECULATIVE_TYPE_MTP,
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
@@ -33,6 +34,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
+ {"mtp", COMMON_SPECULATIVE_TYPE_MTP},
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -599,6 +601,171 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
}
};
+struct common_speculative_state_mtp : public common_speculative_state {
+ llama_context * ctx_tgt = nullptr;
+ llama_context * ctx_mtp = nullptr;
+
+ llama_batch batch; // single token draft step
+ common_sampler * smpl = nullptr;
+ int32_t n_embd = 0;
+
+ uint16_t last_n_drafted = 0;
+ int32_t last_n_accepted = -1;
+
+ common_speculative_state_mtp(enum common_speculative_type type,
+ llama_context * ctx_tgt,
+ llama_context * ctx_mtp)
+ : common_speculative_state(type), ctx_tgt(ctx_tgt), ctx_mtp(ctx_mtp) {
+ GGML_ASSERT(ctx_tgt && ctx_mtp);
+ const llama_model * model_mtp = llama_get_model(ctx_mtp);
+ n_embd = llama_model_n_embd(model_mtp);
+
+ {
+ common_params_sampling sparams;
+ sparams.no_perf = false;
+ sparams.top_k = 1;
+ sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
+ smpl = common_sampler_init(model_mtp, sparams);
+ }
+
+ // TODO: multiple seq support
+ batch = llama_batch_init(/*n_tokens=*/ 1, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
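+ // llama_batch_init() with embd != 0 allocates only the embd buffer, so the token array
+ // is allocated manually here; llama_batch_free() releases both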
+ batch.token = (llama_token *) malloc(sizeof(llama_token));
+ batch.n_tokens = 1;
+ batch.n_seq_id[0] = 1;
+ batch.seq_id[0][0] = 0;
+ batch.logits[0] = 1;
+
+ llama_set_mtp(ctx_tgt, ctx_mtp);
+ }
+
+ ~common_speculative_state_mtp() override {
+ llama_set_mtp(ctx_tgt, nullptr);
+ llama_batch_free(batch);
+ common_sampler_free(smpl);
+ if (ctx_mtp) {
+ llama_free(ctx_mtp);
+ }
+ }
+
+ void begin(const llama_tokens & prompt) override {
+ last_n_accepted = -1;
+ last_n_drafted = 0;
+
+ const int32_t N = (int32_t) prompt.size();
+ if (N <= 0) {
+ return;
+ }
+ const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
+ if (pos_max < N - 1) {
+ LOG_WRN("%s: ctx_mtp pos_max=%d < N-1=%d: "
+ "the streaming hook may not be registered, or not all prefill rows "
+ "had logits enabled; draft quality may degrade\n",
+ __func__, (int) pos_max, N - 1);
+ }
+ }
+
+ void draft(
+ const common_params_speculative & params,
+ const llama_tokens & prompt_tgt,
+ llama_token id_last,
+ llama_tokens & draft_tokens) override {
+ GGML_UNUSED(prompt_tgt);
+ draft_tokens.clear();
+
+ // if nothing was accepted, accept() did not run (or returned early), so the previously
+ // drafted tokens are still in the MTP kv-cache and need to be removed here
+ // TODO: check whether other spec states have the same issue
+ if (last_n_drafted > 0) {
+ const int32_t n_to_drop = (int32_t) last_n_drafted - 1;
+ if (n_to_drop > 0) {
+ const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
+ if (pos_max >= 0) {
+ const llama_pos drop_from = pos_max - n_to_drop + 1;
+ llama_memory_seq_rm(llama_get_memory(ctx_mtp), 0, drop_from, -1);
+ }
+ }
+ last_n_drafted = 0;
+ last_n_accepted = 0;
+ }
+
+ const int32_t n_max = std::max(1, params.draft.n_max);
+ const size_t row_bytes = (size_t) n_embd * sizeof(float);
+
+ llama_token cond_tok = id_last;
+ llama_pos pos = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0) + 1;
+
+ // autoregressive draft loop: step 0 conditions on the target's pre-norm hidden state,
+ // subsequent steps feed the MTP head its own hidden output
+ for (int32_t k = 0; k < n_max; ++k) {
+ ggml_tensor * src;
+ int32_t src_row;
+ if (k == 0) {
+ src = llama_context_get_t_h_pre_norm(ctx_tgt);
+ if (last_n_accepted < 0) {
+ // First draft after begin(): trunk's most recent decode is
+ // the last prefill ubatch; its last row is h_{N-1}.
+ src_row = (src && src->ne[1] > 0) ? (int32_t) src->ne[1] - 1 : 0;
+ } else {
+ src_row = last_n_accepted;
+ }
+ llama_synchronize(ctx_tgt);
+ } else {
+ // subsequent steps: read the MTP head's own hidden output from the MTP context
+ src = llama_context_get_t_mtp_out(ctx_mtp);
+ src_row = src ? (int32_t) src->ne[1] - 1 : 0;
+ llama_synchronize(ctx_mtp);
+ }
+ if (!src) {
+ LOG_WRN("%s: missing source tensor at k=%d; stopping chain\n", __func__, k);
+ return;
+ }
+ ggml_backend_tensor_get(src, batch.embd,
+ (size_t) src_row * row_bytes, row_bytes);
+
+ batch.token[0] = cond_tok;
+ batch.pos[0] = pos;
+
+ const int32_t dec_rc = llama_decode(ctx_mtp, batch);
+ if (dec_rc != 0) {
+ LOG_DBG("%s: llama_decode rc=%d at k=%d; stopping chain\n", __func__, dec_rc, k);
+ return;
+ }
+
+ const llama_token best = common_sampler_sample(smpl, ctx_mtp, 0);
+ common_sampler_accept(smpl, best, /*accept_grammar=*/ false);
+ draft_tokens.push_back(best);
+ cond_tok = best;
+ ++pos;
+ }
+
+ last_n_drafted = (uint16_t) draft_tokens.size();
+ }
+
+ void accept(uint16_t n_accepted) override {
+ const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
+ const int32_t n_drafted_last = (int32_t) last_n_drafted;
+ const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);
+ if (pos_max < 0) {
+ last_n_accepted = (int32_t) n_accepted;
+ return;
+ }
+ if (n_to_drop > 0) {
+ const llama_pos drop_from = pos_max - n_to_drop + 1;
+ llama_memory_seq_rm(llama_get_memory(ctx_mtp), /*seq_id=*/ 0,
+ /*p0=*/ drop_from, /*p1=*/ -1);
+ }
+ last_n_drafted = 0;
+ last_n_accepted = (int32_t) n_accepted;
+ }
+
+ int32_t n_max(const common_params_speculative & params) const override {
+ return std::max(1, params.draft.n_max);
+ }
+
+ int32_t n_min(const common_params_speculative & params) const override {
+ return std::max(1, params.draft.n_min);
+ }
+};
+
// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_state_ngram_simple : public common_speculative_state {
common_ngram_simple_config config;
@@ -952,6 +1119,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
+ case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
@@ -983,11 +1151,24 @@ common_speculative * common_speculative_init(
}
}
+ llama_context * ctx_mtp = nullptr;
+ if (params.type == COMMON_SPECULATIVE_TYPE_MTP) {
+ ctx_mtp = llama_init_from_model(params.mtp.model, params.mtp.cparams);
+ if (ctx_mtp == nullptr) {
+ LOG_ERR("%s: failed to create MTP context\n", __func__);
+ if (ctx_dft) {
+ llama_free(ctx_dft);
+ }
+ return nullptr;
+ }
+ }
+
// Compute the implementations to use based on the config and their order of preference
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
{
bool has_draft = !params.draft.mparams.path.empty();
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
+ bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_mtp != nullptr);
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -1034,6 +1215,9 @@ common_speculative * common_speculative_init(
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
}
+ if (has_mtp) {
+ configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
+ }
}
std::vector<std::unique_ptr<common_speculative_state>> impls = {};
@@ -1058,6 +1242,11 @@ common_speculative * common_speculative_init(
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
break;
}
+ case COMMON_SPECULATIVE_TYPE_MTP: {
+ impls.push_back(std::make_unique<common_speculative_state_mtp>(
+ config.type, ctx_tgt, ctx_mtp));
+ break;
+ }
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 703e3783136..438b790dc4b 100644
@@ -2537,7 +2537,8 @@ extern "C" {
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
- struct ggml_tensor * state);
+ struct ggml_tensor * state,
+ bool keep_intermediates);
// custom operators
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 2b3eb5b5ce6..820b8fb4dc0 100644
@@ -2933,7 +2933,9 @@ struct ggml_cplan ggml_graph_plan(
case GGML_OP_GATED_DELTA_NET:
{
const int64_t S_v = node->src[2]->ne[0];
- cur = S_v * sizeof(float) * n_tasks;
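+ // per-thread scratch: the delta vector (S_v floats) plus a full S_v x S_v state copy
+ // when per-token state snapshots are requested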
+ const bool keep_intermediates = (((const int32_t *) node->op_params)[0] != 0);
+ const int64_t per_thread = S_v + (keep_intermediates ? S_v * S_v : 0);
+ cur = per_thread * sizeof(float) * n_tasks;
} break;
case GGML_OP_COUNT:
{
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index a9bc21da6f0..91cce7cdd06 100644
@@ -10467,16 +10467,20 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const bool kda = (neg0 == S_v);
- // scratch layout per thread: [delta(S_v)]
- const int64_t scratch_per_thread = S_v;
+ const bool keep_intermediates = (bool) ggml_get_op_params_i32(dst, 0);
+
+ const int64_t per_thread = S_v + (keep_intermediates ? S_v * S_v : 0);
const int ith = params->ith;
- float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
+ float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32;
+ float * state_work = keep_intermediates ? (delta + S_v) : nullptr;
// output layout: [attn_scores | new_states]
// attn_scores: S_v * H * n_tokens * n_seqs floats
- // new_states: S_v * S_v * H * n_seqs floats
- const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
+ // new_states: S_v * S_v * H * n_seqs floats (final only)
+ // S_v * S_v * H * n_seqs * T floats (T snaps, keep_intermediates)
+ const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
+ const int64_t state_size_per_snap = S_v * S_v * H * n_seqs;
float * attn_out_base = (float *)dst->data;
float * state_out_base = (float *)dst->data + attn_score_elems;
@@ -10499,9 +10503,11 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const int64_t iq3 = iv3 / rq3;
const int64_t ik3 = iv3 / rk3;
- float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
+ float * s_out = keep_intermediates
+ ? state_work
+ : state_out_base + (iv3 * H + iv1) * S_v * S_v;
- // copy input state into output buffer and operate in-place
+ // copy input state into the working buffer and operate in-place
const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
memcpy(s_out, s_in, S_v * S_v * sizeof(float));
@@ -10552,6 +10558,12 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
}
attn_data += S_v * H; // advance to next token
+
+ if (keep_intermediates) {
+ float * curr_state_o = state_out_base + t * state_size_per_snap +
+ (iv3 * H + iv1) * S_v * S_v;
+ memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float));
+ }
}
}
}
diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu
index 6b44bec7317..fdd0aefe772 100644
@@ -1,6 +1,6 @@
#include "gated_delta_net.cuh"
-template <int S_v, bool KDA>
+template <int S_v, bool KDA, bool keep_intermediates_t>
__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
gated_delta_net_cuda(const float * q,
const float * k,
@@ -37,7 +37,8 @@ gated_delta_net_cuda(const float * q,
float * attn_data = dst;
float * state = dst + attn_score_elems;
- const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
+ const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
+ const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // keep_intermediates_t only
state += state_offset;
curr_state += state_offset + col * S_v;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
@@ -135,17 +136,27 @@ gated_delta_net_cuda(const float * q,
}
attn_data += S_v * H;
+
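+ // per-token snapshot: write the running state for this token into its own slab,
+ // placed after the attention scores in the output buffer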
+ if constexpr (keep_intermediates_t) {
+ float * state_snap = (dst + attn_score_elems) + t * state_size_per_token + state_offset;
+#pragma unroll
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ state_snap[col * S_v + i] = s_shard[r];
+ }
+ }
}
- // Write state back to global memory (transposed layout)
+ if constexpr (!keep_intermediates_t) {
#pragma unroll
- for (int r = 0; r < rows_per_lane; r++) {
- const int i = r * warp_size + lane;
- state[col * S_v + i] = s_shard[r];
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ state[col * S_v + i] = s_shard[r];
+ }
}
}
-template <bool KDA>
+template <bool KDA, bool keep_intermediates_t>
static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
const float * g_d, const float * b_d, const float * s_d,
@@ -169,26 +180,26 @@ static void launch_gated_delta_net(
switch (S_v) {
case 16:
- gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
+ gated_delta_net_cuda<16, KDA, keep_intermediates_t><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 32:
- gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
+ gated_delta_net_cuda<32, KDA, keep_intermediates_t><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 64: {
- gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
+ gated_delta_net_cuda<64, KDA, keep_intermediates_t><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
}
case 128: {
- gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
+ gated_delta_net_cuda<128, KDA, keep_intermediates_t><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
@@ -261,13 +272,27 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
cudaStream_t stream = ctx.stream();
+ const bool keep_intermediates = (((const int32_t *) dst->op_params)[0] != 0);
+
if (kda) {
- launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, neqk1, rq3, scale, stream);
+ if (keep_intermediates) {
+ launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
+ } else {
+ launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
+ }
} else {
- launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, neqk1, rq3, scale, stream);
+ if (keep_intermediates) {
+ launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
+ } else {
+ launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
+ }
}
}
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index ff74cafb5b7..d5d0d06e48f 100644
@@ -887,6 +887,7 @@ typedef struct {
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
+ int32_t keep_intermediates;
} ggml_metal_kargs_gated_delta_net;
typedef struct {
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 5fa162c875c..d090e39c3c6 100644
@@ -1601,6 +1601,8 @@ int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
int ida = 0;
+ const int32_t keep_intermediates = (ggml_get_op_params_i32(op, 0) != 0) ? 1 : 0;
+
ggml_metal_kargs_gated_delta_net args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
@@ -1637,6 +1639,7 @@ int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
+ /*.keep_intermediates =*/ keep_intermediates,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c372eaedeae..a49d75f33cf 100644
@@ -2621,6 +2621,19 @@ kernel void kernel_gated_delta_net_impl(
dst_attn[t*args.ne21*S_v] = y*scale;
}
+ if (args.keep_intermediates) {
+ const uint s_off = args.ne23*args.ne22*args.ne21*S_v;
+ const uint snap_stride = S_v*S_v*args.ne21*args.ne23;
+ const uint state_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
+
+ device float * dst_snap = (device float *) (dst) + s_off + t*snap_stride + state_base;
+
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
+ const short is = tx*NSG + j;
+ dst_snap[is] = ls[j];
+ }
+ }
+
q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;
@@ -2629,11 +2642,13 @@ kernel void kernel_gated_delta_net_impl(
g_ptr += args.ne21*G;
}
- device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
+ if (!args.keep_intermediates) {
+ device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
- FOR_UNROLL (short j = 0; j < NSG; j++) {
- const short is = tx*NSG + j;
- dst_state[is] = ls[j];
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
+ const short is = tx*NSG + j;
+ dst_state[is] = ls[j];
+ }
}
#undef S_v
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 423e01dbff1..a7da439151c 100644
@@ -1497,6 +1497,7 @@ struct vk_op_gated_delta_net_push_constants {
uint32_t sb1, sb2, sb3;
uint32_t neq1, rq3;
float scale;
+ uint32_t keep_intermediates;
};
struct vk_op_ssm_scan_push_constants {
@@ -10706,13 +10707,15 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]);
const float scale = 1.0f / sqrtf((float)S_v);
+ const uint32_t keep_intermediates = (uint32_t)(ggml_get_op_params_i32(dst, 0) != 0);
const vk_op_gated_delta_net_push_constants pc = {
H, n_tokens, n_seqs, s_off,
sq1, sq2, sq3,
sv1, sv2, sv3,
sb1, sb2, sb3,
neq1, rq3,
- scale
+ scale,
+ keep_intermediates
};
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
@@ -16867,8 +16870,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
+ const bool keep_intermediates = (((const int32_t *) tensor->op_params)[0] != 0);
tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
- src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
+ src_clone[2], src_clone[3], src_clone[4], src_clone[5], keep_intermediates);
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
src_clone[0]->flags = tensor->src[0]->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
index 5e9f8308c1d..86a4caca294 100644
@@ -31,6 +31,7 @@ layout(push_constant) uniform Parameters {
uint sb1, sb2, sb3;
uint neq1, rq3;
float scale;
+ uint keep_intermediates;
};
layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; };
@@ -102,6 +103,7 @@ void main() {
const uint state_size = S_V * S_V;
const uint state_base = (seq_id * H + head_id) * state_size;
+ const uint snap_stride = state_size * H * n_seqs;
FLOAT_TYPE s_shard[ROWS_PER_LANE];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
@@ -161,9 +163,18 @@ void main() {
}
attn_off += S_V * H;
+
+ if (keep_intermediates != 0) {
+ const uint snap_base = s_off + t * snap_stride + state_base + col * S_V;
+ [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
+ data_dst[snap_base + r * LANES_PER_COLUMN + lane] = s_shard[r];
+ }
+ }
}
- [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
- data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
+ if (keep_intermediates == 0) {
+ [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
+ data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
+ }
}
}
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 81343eeb14c..fedf4b2642b 100644
@@ -6176,7 +6176,8 @@ struct ggml_tensor * ggml_gated_delta_net(
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
- struct ggml_tensor * state) {
+ struct ggml_tensor * state,
+ bool keep_intermediates) {
GGML_ASSERT(ggml_is_contiguous_rows(q));
GGML_ASSERT(ggml_is_contiguous_rows(k));
GGML_ASSERT(ggml_is_contiguous_rows(v));
@@ -6202,9 +6203,8 @@ struct ggml_tensor * ggml_gated_delta_net(
GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
- // concat output and new_state into a single tensor
- // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs
- const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
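+ // concat output and state(s) into a single tensor
+ // output: S_v * H * n_tokens * n_seqs floats, followed by either the final state
+ // (S_v * S_v * H * n_seqs) or one state snapshot per token when keep_intermediates is set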
+ const int64_t state_rows = keep_intermediates ? n_tokens * S_v * n_seqs : S_v * n_seqs;
+ const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_GATED_DELTA_NET;
@@ -6215,6 +6215,9 @@ struct ggml_tensor * ggml_gated_delta_net(
result->src[4] = beta;
result->src[5] = state;
+ int32_t flag = keep_intermediates ? 1 : 0;
+ ggml_set_op_params(result, &flag, sizeof(flag));
+
return result;
}
diff --git a/include/llama.h b/include/llama.h
index eb869814097..41f2d172dcf 100644
@@ -310,6 +310,9 @@ extern "C" {
// override key-value pairs of the model meta data
const struct llama_model_kv_override * kv_overrides;
+ // override the architecture stored in the GGUF metadata (e.g. to load a model's MTP head as a separate context)
+ const char * override_arch;
+
// Keep the booleans together to avoid misalignment during copy-by-value.
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
@@ -333,6 +336,7 @@ extern "C" {
uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
uint32_t n_ubatch; // physical maximum batch size
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
+ uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback)
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
@@ -530,6 +534,7 @@ extern "C" {
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
+ LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead");
@@ -963,6 +968,20 @@ extern "C" {
// If true, all model tensors are activated during llama_decode() to load and cache their weights.
LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
+ // [EXPERIMENTAL] MTP APIs, accessors for hidden states
+ LLAMA_API struct ggml_tensor * llama_context_get_t_h_pre_norm(struct llama_context * ctx);
+ LLAMA_API struct ggml_tensor * llama_context_get_t_mtp_out (struct llama_context * ctx);
+
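+ // Attach an MTP draft context to a target context (pass NULL to detach).
+ // While attached, the target streams its pre-norm hidden states into the MTP context during llama_decode().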
+ LLAMA_API void llama_set_mtp(
+ struct llama_context * ctx_target,
+ struct llama_context * ctx_mtp);
+
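+ // Same as llama_memory_seq_rm() on the context's memory, but also removes the range from the attached MTP context, if any.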
+ LLAMA_API bool llama_context_seq_rm(
+ struct llama_context * ctx,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1);
+
// Set abort callback
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 633a66fc665..3515f3e9722 100644
@@ -41,6 +41,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_QWEN35, "qwen35" },
{ LLM_ARCH_QWEN35MOE, "qwen35moe" },
+ { LLM_ARCH_QWEN35_MTP, "qwen35_mtp" },
+ { LLM_ARCH_QWEN35MOE_MTP, "qwen35moe_mtp" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },
@@ -756,14 +758,15 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- // NextN/MTP tensors are currently ignored (reserved for future MTP support)
- // These tensors only exist in the last layer(s) and are treated as output tensors
- {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
- {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
- {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
- {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
+ // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
+ // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
+ // the model loader doesn't fault on the block index.
+ {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
+ {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
+ {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+ {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// Nemotron 3 Super
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -876,6 +879,16 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
}
}
+bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch) {
+ switch (arch) {
+ case LLM_ARCH_QWEN35:
+ case LLM_ARCH_QWEN35MOE:
+ return true;
+ default:
+ return false;
+ }
+}
+
bool llm_arch_supports_sm_tensor(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_GROK:
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 8f335f5c7b3..fa1b4aebb8d 100644
@@ -45,6 +45,8 @@ enum llm_arch {
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_QWEN35,
LLM_ARCH_QWEN35MOE,
+ LLM_ARCH_QWEN35_MTP,
+ LLM_ARCH_QWEN35MOE_MTP,
LLM_ARCH_PHI2,
LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE,
@@ -636,3 +638,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch);
bool llm_arch_is_hybrid (const llm_arch & arch);
bool llm_arch_is_diffusion (const llm_arch & arch);
bool llm_arch_supports_sm_tensor(const llm_arch & arch);
+bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index d584415ee48..862dad1c7e2 100644
@@ -42,6 +42,13 @@ llama_context::llama_context(
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
}
+ cparams.n_rs_seq = params.n_rs_seq;
+ if (cparams.n_rs_seq > 0 && !llm_arch_supports_recurrent_partial_rollback(model.arch)) {
+ LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n",
+ __func__, cparams.n_rs_seq);
+ cparams.n_rs_seq = 0;
+ }
+
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
@@ -383,6 +390,9 @@ llama_context::~llama_context() {
}
}
}
+ if (mtp.hook_batch.pos != nullptr) {
+ llama_batch_free(mtp.hook_batch);
+ }
ggml_opt_free(opt_ctx);
}
@@ -1235,13 +1245,21 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
return nullptr;
}
+ if (mtp.ctx_mtp) {
+ handle_mtp_for_ubatch(
+ (int32_t) ubatch.n_tokens,
+ ubatch.token,
+ ubatch.pos,
+ res->t_h_pre_norm);
+ }
+
ret = GGML_STATUS_SUCCESS;
return res;
}
int llama_context::encode(const llama_batch & batch_inp) {
- GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+ GGML_ASSERT(batch_inp.token || batch_inp.embd);
if (batch_inp.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -1531,7 +1549,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
}
int llama_context::decode(const llama_batch & batch_inp) {
- GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
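+ // MTP batches carry both tokens and embeddings, so only require that at least one is set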
+ GGML_ASSERT(batch_inp.token || batch_inp.embd);
if (!memory) {
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
@@ -2946,6 +2964,7 @@ llama_context_params llama_context_default_params() {
/*.n_batch =*/ 2048,
/*.n_ubatch =*/ 512,
/*.n_seq_max =*/ 1,
+ /*.n_rs_seq =*/ 0,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
@@ -3092,6 +3111,10 @@ uint32_t llama_n_seq_max(const llama_context * ctx) {
return ctx->n_seq_max();
}
+uint32_t llama_n_rs_seq(const llama_context * ctx) {
+ return ctx->get_cparams().n_rs_seq;
+}
+
const llama_model * llama_get_model(const llama_context * ctx) {
return &ctx->get_model();
}
@@ -3139,6 +3162,126 @@ void llama_set_warmup(llama_context * ctx, bool warmup) {
ctx->set_warmup(warmup);
}
+ggml_tensor * llama_context::get_t_h_pre_norm() const {
+ return gf_res_prev ? gf_res_prev->t_h_pre_norm : nullptr;
+}
+
+ggml_tensor * llama_context_get_t_h_pre_norm(struct llama_context * ctx) {
+ return ctx ? ctx->get_t_h_pre_norm() : nullptr;
+}
+
+ggml_tensor * llama_context::get_t_mtp_out() const {
+ return gf_res_prev ? gf_res_prev->t_mtp_out : nullptr;
+}
+
+ggml_tensor * llama_context_get_t_mtp_out(struct llama_context * ctx) {
+ return ctx ? ctx->get_t_mtp_out() : nullptr;
+}
+
+void llama_set_mtp(struct llama_context * ctx_target, struct llama_context * ctx_mtp) {
+ if (!ctx_target) return;
+ ctx_target->set_mtp(ctx_mtp);
+}
+
+void llama_context::set_mtp(llama_context * ctx_mtp_in) {
+ if (mtp.ctx_mtp == ctx_mtp_in) return;
+
+ if (mtp.hook_batch.pos != nullptr) {
+ llama_batch_free(mtp.hook_batch);
+ mtp.hook_batch = llama_batch{};
+ }
+
+ mtp.ctx_mtp = ctx_mtp_in;
+ mtp.pending_pos = -1;
+
+ if (mtp.ctx_mtp) {
+ const int32_t n_ub = (int32_t) cparams.n_ubatch;
+ const int32_t n_embd = (int32_t) model.hparams.n_embd;
+ mtp.hook_batch = llama_batch_init(n_ub, n_embd, 1);
+ mtp.hook_batch.token = (llama_token *) malloc(sizeof(llama_token) * n_ub);
+ mtp.pending_h.assign(n_embd, 0.0f);
+ LLAMA_LOG_INFO("%s: MTP draft head registered (ctx_mtp=%p, n_ubatch=%d, n_embd=%d)\n",
+ __func__, (const void *) mtp.ctx_mtp, n_ub, n_embd);
+ } else {
+ mtp.pending_h.clear();
+ mtp.pending_h.shrink_to_fit();
+ LLAMA_LOG_INFO("%s: MTP draft head unregistered\n", __func__);
+ }
+}
+
+void llama_context::handle_mtp_for_ubatch(
+ int32_t n_tokens,
+ const llama_token * tokens,
+ const llama_pos * positions,
+ struct ggml_tensor * t) {
+ if (n_tokens == 0 || t == nullptr) {
+ return;
+ }
+ if (t->ne[1] != (int64_t) n_tokens) {
+ return;
+ }
+ const int64_t n_embd = model.hparams.n_embd;
+ GGML_ASSERT(t->ne[0] == n_embd);
+
+ const int n_rows = (int) n_tokens;
+ const llama_pos pos_start = positions[0];
+
+ const llama_pos pos_max_mtp = llama_memory_seq_pos_max(llama_get_memory(mtp.ctx_mtp), 0);
+ if (pos_start <= pos_max_mtp) {
+ return;
+ }
+
+ const bool pending_continues = mtp.pending_pos >= 0 && mtp.pending_pos + 1 == pos_start;
+ if (mtp.pending_pos >= 0 && !pending_continues) {
+ mtp.pending_pos = -1;
+ }
+
+ synchronize();
+
+ const size_t row_bytes = (size_t) n_embd * sizeof(float);
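+ // each MTP input row pairs hidden state h_k with the next token x_{k+1}: the carried-over
+ // pending row (if contiguous with this ubatch) adds one pair, and the last row of this
+ // ubatch has no next token yet, so it is stashed as the new pending row below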
+ const int n_out = (pending_continues ? 1 : 0) + (n_rows - 1);
+
+ if (n_out > 0) {
+ int out_idx = 0;
+ if (pending_continues) {
+ std::memcpy(mtp.hook_batch.embd + (size_t) out_idx * n_embd,
+ mtp.pending_h.data(), row_bytes);
+ mtp.hook_batch.token[out_idx] = tokens[0];
+ mtp.hook_batch.pos[out_idx] = pos_start;
+ mtp.hook_batch.n_seq_id[out_idx] = 1;
+ mtp.hook_batch.seq_id[out_idx][0] = 0;
+ mtp.hook_batch.logits[out_idx] = 0;
+ ++out_idx;
+ }
+ for (int k = 0; k + 1 < n_rows; ++k) {
+ ggml_backend_tensor_get(t,
+ mtp.hook_batch.embd + (size_t) out_idx * n_embd,
+ (size_t) k * row_bytes,
+ row_bytes);
+ mtp.hook_batch.token[out_idx] = tokens[k + 1];
+ mtp.hook_batch.pos[out_idx] = positions[k + 1];
+ mtp.hook_batch.n_seq_id[out_idx] = 1;
+ mtp.hook_batch.seq_id[out_idx][0] = 0;
+ mtp.hook_batch.logits[out_idx] = 0;
+ ++out_idx;
+ }
+ GGML_ASSERT(out_idx == n_out);
+ mtp.hook_batch.n_tokens = n_out;
+
+ const int32_t rc_dec = llama_decode(mtp.ctx_mtp, mtp.hook_batch);
+ if (rc_dec != 0) {
+ LLAMA_LOG_ERROR("%s: llama_decode(ctx_mtp) failed rc=%d (pos=%d, n=%d)\n",
+ __func__, (int) rc_dec, (int) pos_start, n_out);
+ }
+ }
+
+ // Stash the last h-row as the new pending (for the next ubatch's first
+ // token to pair with).
+ ggml_backend_tensor_get(t, mtp.pending_h.data(),
+ (size_t) (n_rows - 1) * row_bytes, row_bytes);
+ mtp.pending_pos = pos_start + n_rows - 1;
+}
+
void llama_synchronize(llama_context * ctx) {
ctx->synchronize();
}
@@ -3296,6 +3439,22 @@ bool llama_memory_seq_rm(
return mem->seq_rm(seq_id, p0, p1);
}
+bool llama_context_seq_rm(
+ struct llama_context * ctx,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1) {
+ if (!ctx) {
+ return true;
+ }
+ const bool ok = llama_memory_seq_rm(llama_get_memory(ctx), seq_id, p0, p1);
+
+ if (llama_context * ctx_mtp = ctx->get_mtp()) {
+ llama_memory_seq_rm(llama_get_memory(ctx_mtp), 0, p0, p1);
+ }
+ return ok;
+}
+
void llama_memory_seq_cp(
llama_memory_t mem,
llama_seq_id seq_id_src,
diff --git a/src/llama-context.h b/src/llama-context.h
index 53c705eaffc..4c4e04fdca4 100644
@@ -6,6 +6,7 @@
#include "llama-graph.h"
#include "llama-adapter.h"
#include "llama-impl.h"
+#include "llama-mtp.h"
#include "ggml-cpp.h"
#include "ggml-opt.h"
@@ -69,6 +70,12 @@ struct llama_context {
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);
+ ggml_tensor * get_t_h_pre_norm() const;
+ ggml_tensor * get_t_mtp_out() const;
+
+ void set_mtp(llama_context * ctx_mtp_in);
+ llama_context * get_mtp() const { return mtp.ctx_mtp; }
+
llama_token * get_sampled_tokens() const;
llama_token get_sampled_token_ith(int32_t idx);
@@ -233,6 +240,12 @@ struct llama_context {
llm_graph_cb graph_get_cb() const;
+ void handle_mtp_for_ubatch(
+ int32_t n_tokens,
+ const llama_token * tokens,
+ const llama_pos * positions,
+ struct ggml_tensor * t_h_pre_norm);
+
// TODO: read/write lora adapters and cvec
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);
@@ -253,6 +266,8 @@ struct llama_context {
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
+ llama_mtp mtp;
+
std::unique_ptr<llama_memory_i> memory;
// decode output (2-dimensional array: [n_outputs][n_vocab])
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 9d359474132..27aff3a230c 100644
@@ -12,6 +12,7 @@ struct llama_cparams {
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
+ uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 2ff23f87cf4..f97f5254c3a 100644
@@ -3,6 +3,7 @@
#include "llama-impl.h"
#include "llama-model.h"
#include "llama-batch.h"
+#include "llama-context.h"
#include "llama-cparams.h"
#include "llama-kv-cache.h"
@@ -2523,7 +2524,8 @@ ggml_tensor * llm_graph_context::build_rs(
int32_t rs_zero,
const llm_graph_get_rows_fn & get_state_rows) const {
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
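+ // the recurrent state tensor may be widened with extra snapshot groups (n_rs_seq),
+ // so derive the number of rows from the tensor itself instead of rs_size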
+ GGML_UNUSED(rs_size);
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]);
// Clear a single state which will then be copied to the other cleared states.
// Note that this is a no-op when the view is zero-sized.
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 5cb1756c6a9..d5f453dc1be 100644
@@ -18,6 +18,7 @@ struct ggml_tensor;
struct llama_cparams;
struct llama_layer;
+struct llama_context;
struct llama_memory_context_i;
@@ -645,6 +646,8 @@ class llm_graph_result {
ggml_tensor * get_embd() const { return t_embd; }
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
+ ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; }
+
ggml_cgraph * get_gf() const { return gf; }
ggml_context * get_ctx() const { return ctx_compute.get(); }
@@ -673,6 +676,10 @@ class llm_graph_result {
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
+ // MTP related inputs/outputs
+ ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state required for MTP
+ ggml_tensor * t_mtp_out = nullptr; // [n_embd, n_tokens]
+
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
std::map<llama_seq_id, ggml_tensor*> t_candidates;
std::map<llama_seq_id, ggml_tensor*> t_sampled;
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 002d15d415f..b4d21b6ec9d 100644
@@ -229,6 +229,10 @@ uint32_t llama_hparams::n_embd_head_v_mla() const {
}
bool llama_hparams::has_kv(uint32_t il) const {
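+ // when kv_only_nextn is set, only the trailing NextN/MTP layers own a KV cache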
+ if (kv_only_nextn) {
+ return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers);
+ }
+
if (n_layer_kv_from_start >= 0) {
if (il < (uint32_t) n_layer_kv_from_start) {
return true;
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index ac7f9ee8650..1dcad365a79 100644
@@ -92,6 +92,8 @@ struct llama_hparams {
uint32_t moe_latent_size = 0;
uint32_t nextn_predict_layers = 0;
+ bool kv_only_nextn = false;
+
float f_norm_eps;
float f_norm_rms_eps;
float f_norm_group_eps;
diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp
index 10e6b459797..a59561ea54d 100644
@@ -24,6 +24,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
+ uint32_t n_rs_seq,
bool offload,
bool unified,
/* layer filters */
@@ -54,6 +55,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
offload,
rs_size,
n_seq_max,
+ n_rs_seq,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr
diff --git a/src/llama-memory-hybrid-iswa.h b/src/llama-memory-hybrid-iswa.h
index 807c8aac96c..c9d3f9f57c5 100644
@@ -34,6 +34,7 @@ class llama_memory_hybrid_iswa : public llama_memory_i {
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
+ uint32_t n_rs_seq,
bool offload,
bool unified,
/* layer filters */
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index 4ce1af592c1..fd305cab79c 100644
@@ -24,6 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid(
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
+ uint32_t n_rs_seq,
bool offload,
bool unified,
/* layer filters */
@@ -54,6 +55,7 @@ llama_memory_hybrid::llama_memory_hybrid(
offload,
rs_size,
n_seq_max,
+ n_rs_seq,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr
diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h
index 558cafdf984..484eafb7499 100644
@@ -34,6 +34,7 @@ class llama_memory_hybrid : public llama_memory_i {
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
+ uint32_t n_rs_seq,
bool offload,
bool unified,
/* layer filters */
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 4b4fdeb6dec..eda3d2cf203 100644
@@ -24,6 +24,7 @@ llama_memory_recurrent::llama_memory_recurrent(
bool offload,
uint32_t mem_size,
uint32_t n_seq_max,
+ uint32_t n_rs_seq,
const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
const int32_t n_layer = hparams.n_layer;
@@ -31,6 +32,9 @@ llama_memory_recurrent::llama_memory_recurrent(
size = mem_size;
used = 0;
+ this->n_rs_seq = n_rs_seq;
+ rs_idx.assign(n_seq_max, 0);
+
cells.clear();
cells.resize(mem_size);
@@ -92,8 +96,9 @@ llama_memory_recurrent::llama_memory_recurrent(
throw std::runtime_error("failed to create ggml context for rs cache");
}
- ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size);
- ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size);
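+ // widen the recurrent state tensors: group 0 holds the live states, groups 1..n_rs_seq
+ // hold per-token rollback snapshots selected via rs_idx in s_copy()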
+ const uint32_t n_rows = mem_size * (1 + n_rs_seq);
+ ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows);
+ ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows);
ggml_format_name(r, "cache_r_l%d", i);
ggml_format_name(s, "cache_s_l%d", i);
r_l[i] = r;
@@ -133,6 +138,8 @@ void llama_memory_recurrent::clear(bool data) {
head = 0;
used = 0;
+ std::fill(rs_idx.begin(), rs_idx.end(), 0);
+
if (data) {
for (auto & [_, buf] : ctxs_bufs) {
ggml_backend_buffer_clear(buf.get(), 0);
@@ -141,7 +148,6 @@ void llama_memory_recurrent::clear(bool data) {
}
bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
- //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
uint32_t new_head = size;
if (p0 < 0) {
@@ -161,15 +167,22 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
if (0 <= seq_id) {
int32_t & tail_id = cells[seq_id].tail;
if (tail_id >= 0) {
- const auto & cell = cells[tail_id];
- // partial intersection is invalid if it includes the final pos
+ auto & cell = cells[tail_id];
+
+ // partial rollback via per-token snapshot index (bounded by n_rs_seq)
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
- //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
+ const llama_pos rollback = cell.pos - (p0 - 1);
+ if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) {
+ set_rs_idx(seq_id, (uint32_t) rollback);
+ cell.pos = p0 - 1;
+ return true;
+ }
return false;
}
// invalidate tails which will be cleared
if (p0 <= cell.pos && cell.pos < p1) {
tail_id = -1;
+ set_rs_idx(seq_id, 0);
}
}
} else {
@@ -368,6 +381,13 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result;
}
+void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) {
+ if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) {
+ return;
+ }
+ rs_idx[seq_id] = (idx > n_rs_seq) ? n_rs_seq : idx;
+}
+
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const auto & [_, buf] : ctxs_bufs) {
@@ -1159,5 +1179,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
}
int32_t llama_memory_recurrent_context::s_copy(int i) const {
- return mem->cells[i + mem->head].src0;
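+ // if a rollback is pending for this cell's sequence, redirect the copy source to the
+ // matching snapshot group; the index is one-shot and cleared once consumed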
+ const uint32_t cell_idx = i + mem->head;
+ const int32_t src0 = mem->cells[cell_idx].src0;
+
+ if (mem->n_rs_seq == 0) {
+ return src0;
+ }
+
+ uint32_t idx = 0;
+ if (!mem->cells[cell_idx].seq_id.empty()) {
+ const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin();
+ if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) {
+ idx = mem->rs_idx[seq];
+ // reset rollback idx
+ mem->rs_idx[seq] = 0;
+ }
+ }
+ return (int32_t)(idx * mem->size) + src0;
}
diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h
index 47f01d73912..29c58afc9c2 100644
@@ -23,6 +23,7 @@ class llama_memory_recurrent : public llama_memory_i {
bool offload,
uint32_t mem_size,
uint32_t n_seq_max,
+ uint32_t n_rs_seq,
const layer_filter_cb & filter);
~llama_memory_recurrent() = default;
@@ -69,6 +70,13 @@ class llama_memory_recurrent : public llama_memory_i {
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id)
+ // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups
+ uint32_t n_rs_seq = 0;
+ // per-seq rollback index
+ std::vector<uint32_t> rs_idx;
+
+ void set_rs_idx(llama_seq_id seq_id, uint32_t idx);
+
// computed before each graph build
uint32_t n = 0;
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index 4e65a45a50d..c645d0785ab 100644
@@ -1312,9 +1312,16 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte
return tensor;
}
-void llama_model_loader::done_getting_tensors() const {
- if (n_created != n_tensors) {
- throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
+void llama_model_loader::done_getting_tensors(bool partial) const {
+ if (n_created > n_tensors) {
+ throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created));
+ }
+ if (n_created < n_tensors) {
+ if (!partial) {
+ throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
+ }
+ LLAMA_LOG_INFO("%s: partial load: used %d of %d tensors in the file (the rest belong to a sibling model in the same GGUF)\n",
+ __func__, n_created, n_tensors);
}
if (n_tensors_moved > 0) {
LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n",
diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h
index 7b3d6703c03..c476026d3e5 100644
@@ -184,7 +184,7 @@ struct llama_model_loader {
struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true);
- void done_getting_tensors() const;
+ void done_getting_tensors(bool partial = false) const;
void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr);
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 9a5802e3242..540916e98d7 100644
@@ -276,6 +276,10 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
return new llama_model_qwen35(params);
case LLM_ARCH_QWEN35MOE:
return new llama_model_qwen35moe(params);
+ case LLM_ARCH_QWEN35_MTP:
+ return new llama_model_qwen35_mtp(params);
+ case LLM_ARCH_QWEN35MOE_MTP:
+ return new llama_model_qwen35moe_mtp(params);
case LLM_ARCH_MISTRAL3:
return new llama_model_mistral3(params);
case LLM_ARCH_MIMO2:
@@ -309,6 +313,15 @@ llama_model * llama_model_create(llama_model_loader & ml, const llama_model_para
if (arch == LLM_ARCH_UNKNOWN) {
throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'");
}
+ if (params.override_arch != nullptr && params.override_arch[0] != '\0') {
+ const llm_arch arch_override = llm_arch_from_string(params.override_arch);
+ if (arch_override == LLM_ARCH_UNKNOWN) {
+ throw std::runtime_error(std::string("unknown override architecture: '") + params.override_arch + "'");
+ }
+ LLAMA_LOG_INFO("%s: overriding architecture %s -> %s\n",
+ __func__, llm_arch_name(arch), llm_arch_name(arch_override));
+ arch = arch_override;
+ }
return llama_model_create(arch, params);
}
@@ -1400,7 +1413,8 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) {
}
}
- ml.done_getting_tensors();
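+ // the *_MTP archs create only their own subset of tensors from a GGUF that also holds
+ // the base model, so not consuming every tensor in the file is expected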
+ const bool partial_load = (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP);
+ ml.done_getting_tensors(partial_load);
// populate tensors_by_name
for (auto & [_, ctx_ptr] : ml.ctx_map) {
@@ -1945,6 +1959,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.offload_kqv,
std::max((uint32_t) 1, cparams.n_seq_max),
cparams.n_seq_max,
+ cparams.n_rs_seq,
nullptr);
} else if (llm_arch_is_hybrid(arch)) {
// The main difference between hybrid architectures is the
@@ -1978,6 +1993,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* recurrent_type_s */ GGML_TYPE_F32,
/* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* n_seq_max */ cparams.n_seq_max,
+ /* n_rs_seq */ cparams.n_rs_seq,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
/* filter_attn */ std::move(filter_attn),
@@ -1996,6 +2012,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* recurrent_type_v */ GGML_TYPE_F32,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* n_seq_max */ cparams.n_seq_max,
+ /* n_rs_seq */ cparams.n_rs_seq,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
/* filter_attn */ std::move(filter_attn),
@@ -2092,6 +2109,7 @@ llama_model_params llama_model_default_params() {
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
/*.kv_overrides =*/ nullptr,
+ /*.override_arch =*/ nullptr,
/*.vocab_only =*/ false,
/*.use_mmap =*/ true,
/*.use_direct_io =*/ false,
@@ -2316,6 +2334,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_QWEN3VLMOE:
case LLM_ARCH_QWEN35:
case LLM_ARCH_QWEN35MOE:
+ case LLM_ARCH_QWEN35_MTP:
+ case LLM_ARCH_QWEN35MOE_MTP:
return LLAMA_ROPE_TYPE_IMROPE;
case LLM_ARCH_GLM4:
diff --git a/src/llama-mtp.h b/src/llama-mtp.h
new file mode 100644
index 00000000000..65fb3b110c2
@@ -0,0 +1,17 @@
+#pragma once
+
+#include "llama.h"
+
+#include <vector>
+
+struct llama_mtp {
+ llama_context * ctx_mtp = nullptr; // non-owning
+ llama_batch hook_batch = {}; // sized to n_ubatch
+
+ // Cross-ubatch shift state: pair (h_p, x_{p+1}) at MTP pos p+1. The last
+ // h-row of one ubatch needs the first token of the NEXT ubatch to pair
+ // with, so it's stashed here until that next ubatch fires. Resets when
+ // pos_start of the new ubatch != pending_pos+1 (new prompt or seq_rm gap).
+ std::vector<float> pending_h;
+ llama_pos pending_pos = -1;
+};
diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp
index 6bc989c9509..6fd9df17db3 100644
@@ -397,7 +397,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
- ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
+ ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, /*keep_intermediates=*/false);
if (n_tokens == 1) {
cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);
} else {
@@ -420,6 +420,42 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
return {output, new_state};
}
+ggml_tensor * llm_build_delta_net_base::build_delta_net_fused_keep_intermediates(
+ ggml_tensor * q,
+ ggml_tensor * k,
+ ggml_tensor * v,
+ ggml_tensor * g,
+ ggml_tensor * b,
+ ggml_tensor * s,
+ int il) {
+ const int64_t S_k = q->ne[0];
+ const int64_t H_k = q->ne[1];
+ const int64_t n_tokens = q->ne[2];
+ const int64_t n_seqs = q->ne[3];
+
+ const int64_t S_v = v->ne[0];
+ const int64_t H_v = v->ne[1];
+
+ GGML_ASSERT(S_k == S_v);
+ GGML_ASSERT(H_v % H_k == 0);
+
+ GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+ GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+ GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+
+ GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
+ GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+ GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+ GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
+
+ ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, /*keep_intermediates=*/true);
+ // n_tokens > 1 always holds here (the call site requires n_seq_tokens > 1),
+ // so the result is labeled as the chunked variant, following the same naming
+ // convention as build_delta_net_fused.
+ cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il);
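+ // as sliced by the callers: the result packs the attention output
+ // (S_v*H_v*n_tokens*n_seqs elements) followed by n_tokens state snapshots
+ // of S_v*S_v*H_v*n_seqs elements each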
+ return result;
+}
+
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net(
ggml_tensor * q,
ggml_tensor * k,
diff --git a/src/models/models.h b/src/models/models.h
index 6d5f18a8e20..f416045e5db 100644
@@ -56,6 +56,17 @@ struct llm_build_delta_net_base : public llm_graph_context {
ggml_tensor * s,
int il);
+ // fused op with keep_intermediates=true: returns the raw [attn | T snapshots]
+ // output tensor. Caller slices snapshot views and routes them to recurrent slots.
+ ggml_tensor * build_delta_net_fused_keep_intermediates(
+ ggml_tensor * q,
+ ggml_tensor * k,
+ ggml_tensor * v,
+ ggml_tensor * g,
+ ggml_tensor * b,
+ ggml_tensor * s,
+ int il);
+
// choose one of two implementations above based on the number of tokens
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net(
ggml_tensor * q,
@@ -1785,6 +1796,32 @@ struct llama_model_qwen35moe : public llama_model_base {
};
+struct llama_model_qwen35_mtp : public llama_model_base {
+ llama_model_qwen35_mtp(const struct llama_model_params & params) : llama_model_base(params) {}
+ void load_arch_hparams(llama_model_loader & ml) override;
+ void load_arch_tensors(llama_model_loader & ml) override;
+
+ struct graph : public llm_graph_context {
+ graph(const llama_model & model, const llm_graph_params & params);
+ };
+
+ std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
+};
+
+
+struct llama_model_qwen35moe_mtp : public llama_model_base {
+ llama_model_qwen35moe_mtp(const struct llama_model_params & params) : llama_model_base(params) {}
+ void load_arch_hparams(llama_model_loader & ml) override;
+ void load_arch_tensors(llama_model_loader & ml) override;
+
+ struct graph : public llm_graph_context {
+ graph(const llama_model & model, const llm_graph_params & params);
+ };
+
+ std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
+};
+
+
struct llama_model_mistral3 : public llama_model_base {
llama_model_mistral3(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;
diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp
index f276be61ba8..e8ac92331a4 100644
@@ -12,16 +12,23 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
- // Mark recurrent layers (linear attention layers)
+ // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack
+ ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
+ GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
+ hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
+
+ // Mark recurrent layers (linear attention layers). MTP layers are dense
+ // attention-only and must be flagged non-recurrent.
{
+ const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
uint32_t full_attn_interval = 4;
ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
- hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
+ hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0);
}
}
- switch (hparams.n_layer) {
+ switch (hparams.n_layer - hparams.nextn_predict_layers) {
case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break;
case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break;
case 64: type = LLM_TYPE_27B; break;
@@ -83,6 +90,16 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+
+ // NextN/MTP tensors - bound only on the MTP layers; loaded here but unused by the main-pass graph
+ if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+ layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED);
+ layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+ layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+ layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+ layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+ layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+ }
}
}
@@ -111,7 +128,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
- for (int il = 0; il < n_layer; ++il) {
+ // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
+ const int n_transformer_layers = n_layer - (int)hparams.nextn_predict_layers;
+ for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
@@ -128,7 +147,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}
- if (il == n_layer - 1 && inp_out_ids) {
+ if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@@ -160,6 +179,11 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
}
cur = inpL;
+ if (hparams.nextn_predict_layers > 0) {
+ cb(cur, "h_pre_norm", -1);
+ res->t_h_pre_norm = cur;
+ }
+
// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
@@ -303,6 +327,11 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
GGML_ASSERT(ubatch.equal_seqs());
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+ const uint32_t mem_size = mctx_cur->get_size();
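+ // per-token state snapshots need one extra recurrent slot per token beyond the first;
+ // only keep them when enabled (n_rs_seq > 0) and the ubatch fits: n_seq_tokens - 1 <= n_rs_seq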
+ const bool keep_intermediates = (cparams.n_rs_seq > 0)
+ && (n_seq_tokens > 1)
+ && ((uint32_t) n_seq_tokens <= 1 + cparams.n_rs_seq);
+
// Input projections
auto qkvz = build_qkvz(cur, il);
ggml_tensor * qkv_mixed = qkvz.first;
@@ -350,19 +379,37 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
cb(conv_input, "conv_input", il);
- // Update convolution state cache
- // Extract the last (conv_kernel_size - 1) states from conv_input
- ggml_tensor * last_conv_states =
- ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
- conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
- cb(last_conv_states, "last_conv_states", il);
-
- ggml_tensor * state_update_target =
- ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
- kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
- cb(state_update_target, "state_update_target", il);
-
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+ if (!keep_intermediates) {
+ // Update convolution state cache.
+ // Extract the last (conv_kernel_size - 1) states from conv_input
+ ggml_tensor * last_conv_states =
+ ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
+ conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
+ cb(last_conv_states, "last_conv_states", il);
+
+ ggml_tensor * state_update_target =
+ ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
+ kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+ cb(state_update_target, "state_update_target", il);
+
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+ } else {
+ // store per-token conv states: slot 0 holds the state after the last token,
+ // slot k the state from k tokens earlier (so earlier states can be restored)
+ const int64_t row_count = (conv_kernel_size - 1) * conv_channels;
+ const size_t row_size = row_count * ggml_element_size(conv_states_all);
+ for (int64_t t = 1; t <= n_seq_tokens; ++t) {
+ const uint32_t slot = (uint32_t)(n_seq_tokens - t);
+ ggml_tensor * src =
+ ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs,
+ conv_input->nb[1], conv_input->nb[2],
+ t * ggml_element_size(conv_input));
+ ggml_tensor * dst =
+ ggml_view_2d(ctx0, conv_states_all, row_count, n_seqs,
+ conv_states_all->nb[1],
+ ((size_t) slot * mem_size + kv_head) * row_size);
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
+ }
+ }
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
@@ -413,7 +460,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
// if head keys and value keys are different, repeat to force tensors into matching shapes
- // note: need explicit repeat only if we are not using the fused GDN
+ // note: need explicit repeat only if we are not using the fused GDN.
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
@@ -424,18 +471,54 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
- auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
+ ggml_tensor * output;
+
+ if (!keep_intermediates) {
+ auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
+
+ output = attn_out.first;
+ ggml_tensor * new_state = attn_out.second;
+ cb(output, "attn_output", il);
+ cb(new_state, "new_state", il);
+
+ // Update the recurrent states (slot 0 only).
+ ggml_build_forward_expand(gf,
+ ggml_cpy(ctx0, new_state,
+ ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
+ kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+ } else {
+ ggml_tensor * gdn_out = build_delta_net_fused_keep_intermediates(
+ q_conv, k_conv, v_conv, gate, beta, state, il);
- ggml_tensor * output = attn_out.first;
- ggml_tensor * new_state = attn_out.second;
- cb(output, "attn_output", il);
- cb(new_state, "new_state", il);
+ const int64_t S_v = head_v_dim;
+ const int64_t H_v = num_v_heads;
- // Update the recurrent states
- ggml_build_forward_expand(gf,
- ggml_cpy(ctx0, new_state,
- ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
- kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+ const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs;
+ const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs;
+
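+ // gdn_out packs [attn output | n_seq_tokens state snapshots]; slice the
+ // attention output here and route each snapshot to its recurrent slot below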
+ output = ggml_view_4d(ctx0, gdn_out,
+ S_v, H_v, n_seq_tokens, n_seqs,
+ ggml_row_size(gdn_out->type, S_v),
+ ggml_row_size(gdn_out->type, S_v * H_v),
+ ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens),
+ 0);
+ cb(output, "attn_output", il);
+
+ const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all);
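+ // same slot layout as the conv states above: slot n_seq_tokens - t holds the state after token t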
+ for (int64_t t = 1; t <= n_seq_tokens; ++t) {
+ const uint32_t slot = (uint32_t)(n_seq_tokens - t);
+ ggml_tensor * src = ggml_view_4d(ctx0, gdn_out,
+ S_v, S_v, H_v, n_seqs,
+ ggml_row_size(gdn_out->type, S_v),
+ ggml_row_size(gdn_out->type, S_v * S_v),
+ ggml_row_size(gdn_out->type, S_v * S_v * H_v),
+ ggml_row_size(gdn_out->type, attn_score_elems + (t - 1) * state_size_per_snap));
+ ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all,
+ hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
+ ((size_t) slot * mem_size + kv_head) * row_size);
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
+ }
+ }
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
diff --git a/src/models/qwen35_mtp.cpp b/src/models/qwen35_mtp.cpp
new file mode 100644
index 00000000000..2ccc5f62772
@@ -0,0 +1,205 @@
+#include "models.h"
+
+void llama_model_qwen35_mtp::load_arch_hparams(llama_model_loader & ml) {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
+
+ ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
+ GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0");
+ GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer);
+
+ // only the MTP layers get a KV cache; trunk layers are skipped.
+ hparams.kv_only_nextn = true;
+ hparams.n_layer_kv_from_start = -1;
+ for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+ hparams.recurrent_layer_arr[i] = false;
+ }
+
+ type = LLM_TYPE_UNKNOWN;
+}
+
+void llama_model_qwen35_mtp::load_arch_tensors(llama_model_loader &) {
+ LLAMA_LOAD_LOCALS;
+
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, TENSOR_NOT_REQUIRED);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+ if (output == nullptr) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
+ }
+
+ const uint32_t n_main = n_layer - hparams.nextn_predict_layers;
+ for (int i = 0; i < n_layer; ++i) {
+ if (static_cast<uint32_t>(i) < n_main) {
+ continue; // trunk layer — owned by the sibling QWEN35 model
+ }
+
+ auto & layer = layers[i];
+
+ // MTP block looks like a full-attention Qwen3.5 decoder block.
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
+
+ create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
+
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+
+ // NextN-specific tensors that define the MTP block.
+ layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0);
+ layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0);
+ layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0);
+ layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+ layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+ layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+ }
+}
+
+std::unique_ptr<llm_graph_context> llama_model_qwen35_mtp::build_arch_graph(const llm_graph_params & params) const {
+ return std::make_unique<graph>(*this, params);
+}
+
+// LLM_ARCH_QWEN35_MTP draft head for the Qwen3.5/3.6 series
+llama_model_qwen35_mtp::graph::graph(const llama_model & model, const llm_graph_params & params)
+ : llm_graph_context(params) {
+ GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0");
+ GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35_MTP currently only supports a single MTP block");
+
+ const int64_t n_embd_head = hparams.n_embd_head_v();
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+
+ // The MTP block keeps its layer index from the source checkpoint: the first index past the trunk layers.
+ const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
+ const auto & layer = model.layers[il];
+
+ GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
+ GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm");
+ GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm");
+
+ int sections[4];
+ std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+ auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd);
+
+ inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ ggml_set_input(inp->tokens);
+
+ inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
+ ggml_set_input(inp->embd);
+ ggml_set_name(inp->embd, "mtp_h_input");
+
+ ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
+
+ ggml_tensor * h_input = inp->embd;
+ ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
+ cb(tok_embd, "mtp_tok_embd", il);
+
+ res->add_input(std::move(inp));
+
+ ggml_tensor * inp_pos = build_inp_pos();
+ auto * inp_attn = build_attn_inp_kv();
+
+ ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
+ cb(h_norm, "mtp_hnorm", il);
+
+ ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
+ cb(e_norm, "mtp_enorm", il);
+
+ ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0);
+ cb(concat, "mtp_concat", il);
+
+ ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
+ cb(cur, "mtp_eh_proj", il);
+
+ ggml_tensor * inpSA = cur;
+
+ cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
+ cb(cur, "mtp_attn_norm", il);
+
+ ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s);
+ cb(Qcur_full, "mtp_Qcur_full", il);
+
+ ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full,
+ n_embd_head, n_head, n_tokens,
+ ggml_element_size(Qcur_full) * n_embd_head * 2,
+ ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+ 0);
+ Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il);
+ cb(Qcur, "mtp_Qcur_normed", il);
+
+ ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full,
+ n_embd_head, n_head, n_tokens,
+ ggml_element_size(Qcur_full) * n_embd_head * 2,
+ ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+ ggml_element_size(Qcur_full) * n_embd_head);
+ gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+ cb(gate, "mtp_gate", il);
+
+ ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il);
+ cb(Kcur, "mtp_Kcur_normed", il);
+
+ ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+ cb(Vcur, "mtp_Vcur", il);
+
+ Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
+ n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
+ n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ const float kq_scale = hparams.f_attention_scale == 0.0f
+ ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+ cur = build_attn(inp_attn,
+ nullptr, nullptr, nullptr,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+ cb(cur, "mtp_attn_pregate", il);
+
+ cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
+ cur = build_lora_mm(layer.wo, cur, layer.wo_s);
+ cb(cur, "mtp_attn_out", il);
+
+ cur = ggml_add(ctx0, cur, inpSA);
+ cb(cur, "mtp_attn_residual", il);
+
+ ggml_tensor * ffn_residual = cur;
+ cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
+ cb(cur, "mtp_attn_post_norm", il);
+
+ cur = build_ffn(cur,
+ layer.ffn_up, nullptr, layer.ffn_up_s,
+ layer.ffn_gate, nullptr, layer.ffn_gate_s,
+ layer.ffn_down, nullptr, layer.ffn_down_s,
+ nullptr,
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
+ cb(cur, "mtp_ffn_out", il);
+
+ cur = ggml_add(ctx0, cur, ffn_residual);
+ cb(cur, "mtp_post_ffn", il);
+
+ // snapshot the MTP block's post-FFN hidden state for the autoregressive draft loop when more than one MTP token is drafted
+ res->t_mtp_out = cur;
+
+ ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
+ ? layer.nextn.shared_head_norm
+ : model.output_norm;
+ GGML_ASSERT(head_norm_w && "QWEN35_MTP: missing both nextn.shared_head_norm and output_norm");
+ cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
+ cb(cur, "mtp_shared_head_norm", -1);
+
+ ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
+ GGML_ASSERT(head_w && "QWEN35_MTP: missing LM head (nextn.shared_head_head or model.output)");
+ cur = build_lora_mm(head_w, cur);
+ cb(cur, "result_output", -1);
+
+ res->t_logits = cur;
+ ggml_build_forward_expand(gf, cur);
+}
diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp
index cf05dc9d61c..38d0998761f 100644
@@ -15,16 +15,23 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
- // Mark recurrent layers (linear attention layers)
+ // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack
+ ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
+ GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
+ hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
+
+ // Mark recurrent layers (linear attention layers). MTP layers are dense
+ // attention-only and must be flagged non-recurrent.
{
+ const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
uint32_t full_attn_interval = 4;
ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
- hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
+ hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0);
}
}
- switch (hparams.n_layer) {
+ switch (hparams.n_layer - hparams.nextn_predict_layers) {
case 40: type = LLM_TYPE_35B_A3B; break;
case 48: type = LLM_TYPE_122B_A10B; break;
case 60: type = LLM_TYPE_397B_A17B; break;
@@ -96,6 +103,16 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) {
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0);
+
+ // NextN/MTP tensors - bound only on the MTP layers; loaded here but unused by the main-pass graph
+ if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+ layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED);
+ layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+ layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+ layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+ layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+ layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+ }
}
}
@@ -124,7 +141,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
- for (int il = 0; il < n_layer; ++il) {
+ // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
+ const int n_transformer_layers = n_layer - (int)hparams.nextn_predict_layers;
+ for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
@@ -141,7 +160,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}
- if (il == n_layer - 1 && inp_out_ids) {
+ if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@@ -173,6 +192,11 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
}
cur = inpL;
+ if (hparams.nextn_predict_layers > 0) {
+ cb(cur, "h_pre_norm", -1);
+ res->t_h_pre_norm = cur;
+ }
+
// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
@@ -316,6 +340,11 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
GGML_ASSERT(ubatch.equal_seqs());
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+ const uint32_t mem_size = mctx_cur->get_size();
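+ // per-token state snapshots need one extra recurrent slot per token beyond the first;
+ // only keep them when enabled (n_rs_seq > 0) and the ubatch fits: n_seq_tokens - 1 <= n_rs_seq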
+ const bool keep_intermediates = (cparams.n_rs_seq > 0)
+ && (n_seq_tokens > 1)
+ && ((uint32_t) n_seq_tokens <= 1 + cparams.n_rs_seq);
+
// Input projections
auto qkvz = build_qkvz(cur, il);
ggml_tensor * qkv_mixed = qkvz.first;
@@ -363,19 +392,37 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
cb(conv_input, "conv_input", il);
- // Update convolution state cache
- // Extract the last (conv_kernel_size - 1) states from conv_input
- ggml_tensor * last_conv_states =
- ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
- conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
- cb(last_conv_states, "last_conv_states", il);
+ if (!keep_intermediates) {
+ // Update convolution state cache.
+ // Extract the last (conv_kernel_size - 1) states from conv_input
+ ggml_tensor * last_conv_states =
+ ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
+ conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
+ cb(last_conv_states, "last_conv_states", il);
- ggml_tensor * state_update_target =
- ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
- kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
- cb(state_update_target, "state_update_target", il);
+ ggml_tensor * state_update_target =
+ ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
+ kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+ cb(state_update_target, "state_update_target", il);
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+ } else {
+ // store per-token conv states: slot 0 holds the state after the last token,
+ // slot k the state from k tokens earlier (so earlier states can be restored)
+ const int64_t row_count = (conv_kernel_size - 1) * conv_channels;
+ const size_t row_size = row_count * ggml_element_size(conv_states_all);
+ for (int64_t t = 1; t <= n_seq_tokens; ++t) {
+ const uint32_t slot = (uint32_t)(n_seq_tokens - t);
+ ggml_tensor * src =
+ ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs,
+ conv_input->nb[1], conv_input->nb[2],
+ t * ggml_element_size(conv_input));
+ ggml_tensor * dst =
+ ggml_view_2d(ctx0, conv_states_all, row_count, n_seqs,
+ conv_states_all->nb[1],
+ ((size_t) slot * mem_size + kv_head) * row_size);
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
+ }
+ }
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
@@ -426,7 +473,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
// if head keys and value keys are different, repeat to force tensors into matching shapes
- // note: need explicit repeat only if we are not using the fused GDN
+ // note: need explicit repeat only if we are not using the fused GDN.
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
@@ -437,18 +484,54 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
- auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
+ ggml_tensor * output;
- ggml_tensor * output = attn_out.first;
- ggml_tensor * new_state = attn_out.second;
- cb(output, "attn_output", il);
- cb(new_state, "new_state", il);
+ if (!keep_intermediates) {
+ auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
- // Update the recurrent states
- ggml_build_forward_expand(gf,
- ggml_cpy(ctx0, new_state,
- ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
- kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+ output = attn_out.first;
+ ggml_tensor * new_state = attn_out.second;
+ cb(output, "attn_output", il);
+ cb(new_state, "new_state", il);
+
+ // Update the recurrent states (slot 0 only).
+ ggml_build_forward_expand(gf,
+ ggml_cpy(ctx0, new_state,
+ ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
+ kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+ } else {
+ ggml_tensor * gdn_out = build_delta_net_fused_keep_intermediates(
+ q_conv, k_conv, v_conv, gate, beta, state, il);
+
+ const int64_t S_v = head_v_dim;
+ const int64_t H_v = num_v_heads;
+
+ const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs;
+ const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs;
+
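+ // gdn_out packs [attn output | n_seq_tokens state snapshots]; slice the
+ // attention output here and route each snapshot to its recurrent slot below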
+ output = ggml_view_4d(ctx0, gdn_out,
+ S_v, H_v, n_seq_tokens, n_seqs,
+ ggml_row_size(gdn_out->type, S_v),
+ ggml_row_size(gdn_out->type, S_v * H_v),
+ ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens),
+ 0);
+ cb(output, "attn_output", il);
+
+ const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all);
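+ // same slot layout as the conv states above: slot n_seq_tokens - t holds the state after token t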
+ for (int64_t t = 1; t <= n_seq_tokens; ++t) {
+ const uint32_t slot = (uint32_t)(n_seq_tokens - t);
+ ggml_tensor * src = ggml_view_4d(ctx0, gdn_out,
+ S_v, S_v, H_v, n_seqs,
+ ggml_row_size(gdn_out->type, S_v),
+ ggml_row_size(gdn_out->type, S_v * S_v),
+ ggml_row_size(gdn_out->type, S_v * S_v * H_v),
+ ggml_row_size(gdn_out->type, attn_score_elems + (t - 1) * state_size_per_snap));
+ ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all,
+ hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
+ ((size_t) slot * mem_size + kv_head) * row_size);
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
+ }
+ }
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
diff --git a/src/models/qwen35moe_mtp.cpp b/src/models/qwen35moe_mtp.cpp
new file mode 100644
index 00000000000..2f8db48adc1
@@ -0,0 +1,250 @@
+#include "models.h"
+
+void llama_model_qwen35moe_mtp::load_arch_hparams(llama_model_loader & ml) {
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+ ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
+
+ ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
+ GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0");
+ GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer);
+ GGML_ASSERT(hparams.n_expert > 0 && "QWEN35MOE_MTP requires n_expert > 0");
+
+ // only the MTP layers get a KV cache; trunk layers are skipped.
+ hparams.kv_only_nextn = true;
+ hparams.n_layer_kv_from_start = -1;
+ for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+ hparams.recurrent_layer_arr[i] = false;
+ }
+
+ type = LLM_TYPE_UNKNOWN;
+}
+
+void llama_model_qwen35moe_mtp::load_arch_tensors(llama_model_loader &) {
+ LLAMA_LOAD_LOCALS;
+
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, TENSOR_NOT_REQUIRED);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+ if (output == nullptr) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
+ }
+
+ const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+ const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
+
+ const uint32_t n_main = n_layer - hparams.nextn_predict_layers;
+ for (int i = 0; i < n_layer; ++i) {
+ if (static_cast<uint32_t>(i) < n_main) {
+ continue; // trunk layer — owned by the sibling QWEN35MOE model
+ }
+
+ auto & layer = layers[i];
+
+ // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN.
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
+
+ create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
+
+ // Routed experts
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
+ create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
+
+ // Shared experts
+ layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
+ layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0);
+
+ // NextN-specific tensors that define the MTP block.
+ layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0);
+ layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0);
+ layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0);
+ layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+ layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+ layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+ }
+}
+
+std::unique_ptr<llm_graph_context> llama_model_qwen35moe_mtp::build_arch_graph(const llm_graph_params & params) const {
+ return std::make_unique<graph>(*this, params);
+}
+
+// LLM_ARCH_QWEN35MOE_MTP draft head for Qwen3.5/3.6 MoE
+llama_model_qwen35moe_mtp::graph::graph(const llama_model & model, const llm_graph_params & params)
+ : llm_graph_context(params) {
+ GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0");
+ GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE_MTP currently only supports a single MTP block");
+
+ const int64_t n_embd_head = hparams.n_embd_head_v();
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+
+ const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
+ const auto & layer = model.layers[il];
+
+ GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
+ GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm");
+ GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm");
+ GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp");
+
+ int sections[4];
+ std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+ auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd);
+
+ inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ ggml_set_input(inp->tokens);
+
+ inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
+ ggml_set_input(inp->embd);
+ ggml_set_name(inp->embd, "mtp_h_input");
+
+ ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
+
+ ggml_tensor * h_input = inp->embd;
+ ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
+ cb(tok_embd, "mtp_tok_embd", il);
+
+ res->add_input(std::move(inp));
+
+ ggml_tensor * inp_pos = build_inp_pos();
+ auto * inp_attn = build_attn_inp_kv();
+
+ ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
+ cb(h_norm, "mtp_hnorm", il);
+
+ ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
+ cb(e_norm, "mtp_enorm", il);
+
+ ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0);
+ cb(concat, "mtp_concat", il);
+
+ ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
+ cb(cur, "mtp_eh_proj", il);
+
+ ggml_tensor * inpSA = cur;
+
+ cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
+ cb(cur, "mtp_attn_norm", il);
+
+ ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s);
+ cb(Qcur_full, "mtp_Qcur_full", il);
+
+ ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full,
+ n_embd_head, n_head, n_tokens,
+ ggml_element_size(Qcur_full) * n_embd_head * 2,
+ ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+ 0);
+ Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il);
+ cb(Qcur, "mtp_Qcur_normed", il);
+
+ ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full,
+ n_embd_head, n_head, n_tokens,
+ ggml_element_size(Qcur_full) * n_embd_head * 2,
+ ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+ ggml_element_size(Qcur_full) * n_embd_head);
+ gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+ cb(gate, "mtp_gate", il);
+
+ ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il);
+ cb(Kcur, "mtp_Kcur_normed", il);
+
+ ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+ cb(Vcur, "mtp_Vcur", il);
+
+ Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
+ n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
+ n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ const float kq_scale = hparams.f_attention_scale == 0.0f
+ ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+ cur = build_attn(inp_attn,
+ nullptr, nullptr, nullptr,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+ cb(cur, "mtp_attn_pregate", il);
+
+ cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
+ cur = build_lora_mm(layer.wo, cur, layer.wo_s);
+ cb(cur, "mtp_attn_out", il);
+
+ cur = ggml_add(ctx0, cur, inpSA);
+ cb(cur, "mtp_attn_residual", il);
+
+ ggml_tensor * ffn_residual = cur;
+ cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
+ cb(cur, "mtp_attn_post_norm", il);
+
+ // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe).
+ ggml_tensor * moe_out =
+ build_moe_ffn(cur,
+ layer.ffn_gate_inp,
+ layer.ffn_up_exps,
+ layer.ffn_gate_exps,
+ layer.ffn_down_exps,
+ nullptr,
+ n_expert, n_expert_used,
+ LLM_FFN_SILU, true,
+ hparams.expert_weights_scale,
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
+ nullptr, layer.ffn_gate_up_exps,
+ layer.ffn_up_exps_s,
+ layer.ffn_gate_exps_s,
+ layer.ffn_down_exps_s);
+ cb(moe_out, "mtp_ffn_moe_out", il);
+
+ if (layer.ffn_up_shexp != nullptr) {
+ ggml_tensor * ffn_shexp =
+ build_ffn(cur,
+ layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s,
+ layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s,
+ layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s,
+ nullptr,
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
+ cb(ffn_shexp, "mtp_ffn_shexp", il);
+
+ ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur);
+ shared_gate = ggml_sigmoid(ctx0, shared_gate);
+ cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il);
+
+ ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
+ cb(ffn_shexp, "mtp_ffn_shexp_gated", il);
+
+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
+ } else {
+ cur = moe_out;
+ }
+ cb(cur, "mtp_ffn_out", il);
+
+ cur = ggml_add(ctx0, cur, ffn_residual);
+ cb(cur, "mtp_post_ffn", il);
+
+ res->t_mtp_out = cur;
+
+ ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
+ ? layer.nextn.shared_head_norm
+ : model.output_norm;
+ GGML_ASSERT(head_norm_w && "QWEN35MOE_MTP: missing both nextn.shared_head_norm and output_norm");
+ cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
+ cb(cur, "mtp_shared_head_norm", -1);
+
+ ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
+ GGML_ASSERT(head_w && "QWEN35MOE_MTP: missing LM head (nextn.shared_head_head or model.output)");
+ cur = build_lora_mm(head_w, cur);
+ cb(cur, "result_output", -1);
+
+ res->t_logits = cur;
+ ggml_build_forward_expand(gf, cur);
+}