Safe wrapper around the C++ MTP draft session.
`MtpSession` pairs a target `LlamaContext` with an MTP draft `LlamaContext` (built with `crate::context::params::LlamaContextType::Mtp`). It drives the multi-token-prediction speculative-decoding loop introduced in upstream llama.cpp PR #22673.

The draft algorithm lives in upstream's `common/speculative.cpp` (`common_speculative_impl_draft_mtp`). This module wraps it through a stable C shim in `llama-cpp-sys-4/mtp_shim/`.
## Upstream behaviour (llama.cpp #23269+)

After the MTP clean-up in #23269:

- Draft sampling uses `top_k = 10` inside upstream (not configurable from Rust). `MtpSessionConfig::p_min` filters out low-confidence draft tokens (default `0.0`).
- The upstream CLI default for `n_max` is `3`. Set `MtpSessionConfig::n_draft_max` explicitly, because the optimal value depends on the model and quantisation (see `MTP.md` on GitHub).
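The following simplified sketch shows how `p_min` and `n_draft_max` can bound a draft. It assumes that the MTP drafter applies `p_min` the same way as upstream's generic draft loop: drafting stops at the first step whose most likely candidate has a probability below `p_min`, or once `n_draft_max` tokens have been collected. `Candidate` and `truncate_draft` are hypothetical names, not part of this crate. Real drafting is also sequential, since each step's candidates depend on the token drafted before it, whereas the sketch uses precomputed lists.

```rust
/// A candidate token with its draft probability (hypothetical type).
struct Candidate {
    token: i32,
    p: f32,
}

/// Greedily builds a draft from per-step candidate lists.
/// Each list is assumed to be sorted by descending probability
/// (upstream keeps the top 10 candidates per step).
/// Stops when the best candidate's probability is below `p_min`
/// or when `n_draft_max` tokens have been collected.
fn truncate_draft(steps: &[Vec<Candidate>], n_draft_max: usize, p_min: f32) -> Vec<i32> {
    let mut draft = Vec::new();
    for candidates in steps {
        if draft.len() >= n_draft_max {
            break; // draft length cap reached
        }
        // The most likely candidate is the greedy draft token.
        let Some(best) = candidates.first() else { break };
        if best.p < p_min {
            break; // low-confidence token: stop drafting here
        }
        draft.push(best.token);
    }
    draft
}

fn main() {
    let steps = vec![
        vec![Candidate { token: 11, p: 0.9 }, Candidate { token: 12, p: 0.05 }],
        vec![Candidate { token: 21, p: 0.6 }],
        vec![Candidate { token: 31, p: 0.2 }],
        vec![Candidate { token: 41, p: 0.95 }],
    ];
    // p_min = 0.0 (the default) filters nothing; only n_draft_max = 3 limits the draft.
    println!("{:?}", truncate_draft(&steps, 3, 0.0)); // [11, 21, 31]
    // p_min = 0.5 stops at the third step, whose best candidate has p = 0.2.
    println!("{:?}", truncate_draft(&steps, 3, 0.5)); // [11, 21]
}
```

With the default `p_min = 0.0`, the draft length is controlled only by `n_draft_max`. Raising `p_min` trades shorter drafts for a higher expected acceptance rate.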
## Quick start
```rust
use llama_cpp_4::context::params::{LlamaContextParams, LlamaContextType};
use llama_cpp_4::mtp::{MtpSession, MtpSessionConfig};

// `model` (a loaded model) and `backend` (the initialised llama backend) are assumed to be in scope.
let n_draft_max = 3;

// Target context: the regular context that verifies drafts.
let target = model.new_context(&backend, LlamaContextParams::default())?;

// Draft context: same model, built as an MTP context.
let draft = model.new_context(
    &backend,
    LlamaContextParams::default()
        .with_ctx_type(LlamaContextType::Mtp)
        .with_n_rs_seq(n_draft_max.max(4)),
)?;

let config = MtpSessionConfig::new(1, n_draft_max).with_p_min(0.0);
let mut session = MtpSession::new_with_config(&target, &draft, config)?;
```

## Speculative loop
Each generation step runs the following sequence:

```rust
// 1. Target prefill or verify decode (you build the batch).
target.decode(&mut batch)?;

// 2. Tell MTP about the batch just decoded on the target.
session.process(&batch)?;

// 3. Ask for draft tokens starting from the last accepted token.
let drafts = session.draft(0, n_past, last_token)?;

// 4. Verify the drafts on the target (compare logits or sample; this part is your code).
let n_accepted: u16 = todo!("verify `drafts` on the target");

// 5. Sync the draft recurrent state with what the target accepted.
session.accept(0, n_accepted)?;
```

Call `MtpSession::begin` once per fresh generation if you want upstream prompt tracking (optional for MTP). When finished, call `MtpSession::print_stats` to log the draft/accept counters through llama.cpp's log callback.

A full runnable implementation is in `examples/mtp/`.
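Step 4 is left to the caller. One common choice is greedy verification. You decode `last_token` followed by the drafts on the target in one batch, take the target's argmax at each position, and accept drafts up to the first mismatch. The sketch below makes three assumptions:

- Sampling is greedy (argmax).
- You have already copied one logits row per batch position out of the target context.
- `MtpSession::accept` expects the number of matching draft tokens.

`argmax` and `verify_greedy` are hypothetical helpers, not part of this crate.

```rust
/// Index of the largest logit (greedy sampling); ties keep the first index.
/// Assumes `logits` is non-empty.
fn argmax(logits: &[f32]) -> usize {
    let mut best = 0;
    for (i, &x) in logits.iter().enumerate() {
        if x > logits[best] {
            best = i;
        }
    }
    best
}

/// Greedy verification of a draft.
///
/// `rows[i]` holds the target logits after the i-th token of the verify batch
/// `[last_token, drafts[0], ..., drafts[n - 1]]`, so `rows.len() == drafts.len() + 1`.
/// Returns `(n_accepted, next_token)`: the number of drafts that match the target's
/// argmax, plus the target's own token at the first mismatch (or the extra token
/// after the last draft when every draft matches).
fn verify_greedy(drafts: &[i32], rows: &[Vec<f32>]) -> (u16, i32) {
    assert_eq!(rows.len(), drafts.len() + 1, "need one logits row per batch position");
    let mut n_accepted = 0;
    while n_accepted < drafts.len() && argmax(&rows[n_accepted]) as i32 == drafts[n_accepted] {
        n_accepted += 1;
    }
    let next_token = argmax(&rows[n_accepted]) as i32;
    (n_accepted as u16, next_token)
}

fn main() {
    // Vocabulary of 4 tokens; the draft proposes token 2, then token 1.
    let drafts = [2, 1];
    let rows: Vec<Vec<f32>> = vec![
        vec![0.1, 0.0, 3.0, 0.2], // target predicts 2: draft 0 accepted
        vec![0.0, 0.5, 0.1, 2.0], // target predicts 3: draft 1 rejected
        vec![1.0, 0.0, 0.0, 0.0], // unused: follows a rejected draft
    ];
    let (n_accepted, next) = verify_greedy(&drafts, &rows);
    println!("accepted {n_accepted}, next token {next}"); // accepted 1, next token 3
}
```

Under these assumptions, `n_accepted` is the value passed to `session.accept(0, n_accepted)`. `next_token` becomes the `last_token` for the following iteration.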
## Embedding requirements

| Method | Typical MTP value | Meaning |
|---|---|---|
| `MtpSession::need_embd_pre_norm` | `true` | Next-n hidden states (upstream name) |
| `MtpSession::need_embd` | `false` | Post-norm / sequence embeddings are not used |
The Rust API keeps the `*_pre_norm` names; the upstream C API has used `*_nextn` since the Jun 2026 llama.cpp bump. Session initialisation configures embedding extraction on both contexts automatically, so calling `LlamaContext::set_embeddings_pre_norm` manually is rarely needed.
## Structs

- `MtpSession`: Owned MTP draft session.
- `MtpSessionConfig`: Parameters for `MtpSession::new_with_config`.

## Enums

- `MtpSessionError`: Errors raised by the MTP draft session.