Module mtp

Safe wrapper around the C++ MTP draft session.

MtpSession pairs a target LlamaContext with an MTP draft LlamaContext (built with crate::context::params::LlamaContextType::Mtp) and drives the multi-token-prediction speculative-decoding loop introduced in upstream llama.cpp PR #22673.

The draft algorithm lives in upstream’s common/speculative.cpp (common_speculative_impl_draft_mtp). This module wraps it through a stable C shim in llama-cpp-sys-4/mtp_shim/.

§Upstream behaviour (llama.cpp #23269+)

Since the MTP clean-up in llama.cpp #23269:

  • Upstream samples drafts with top_k = 10; this value is not configurable from Rust.
  • MtpSessionConfig::p_min filters out low-confidence draft tokens (default 0.0, which effectively disables the filter).
  • The upstream CLI default for n_max is 3. Set MtpSessionConfig::n_draft_max explicitly, because the optimal value depends on the model and quantization (see MTP.md on GitHub).
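To make the interaction between n_draft_max and p_min concrete, here is a simplified, self-contained model of a draft chain. It assumes that drafting stops at the first step whose most likely candidate has probability below p_min, which is how upstream's classic draft loop behaves; treat that as an assumption for the MTP path. The `Candidates` lists stand in for the top-k candidates and are precomputed here, whereas real MTP drafting conditions each step on the previous draft token. `draft_chain` is a hypothetical helper, not part of llama_cpp_4.

```rust
/// Candidates for one draft step, sorted by descending probability:
/// (token id, probability). Stand-in for the top-k list upstream samples from.
type Candidates = Vec<(i32, f32)>;

/// Simplified draft chain: take the most likely candidate at each step and
/// stop after `n_draft_max` tokens, or earlier if the best candidate's
/// probability is below `p_min` (assumed stopping rule, see lead-in).
fn draft_chain(steps: &[Candidates], n_draft_max: usize, p_min: f32) -> Vec<i32> {
    let mut drafts = Vec::with_capacity(n_draft_max);
    for cands in steps.iter().take(n_draft_max) {
        match cands.first() {
            // Confident enough: keep the token and continue drafting.
            Some(&(tok, p)) if p >= p_min => drafts.push(tok),
            // Empty list or low-confidence head: end the chain here.
            _ => break,
        }
    }
    drafts
}
```

With the default p_min = 0.0 the chain always runs to n_draft_max (as long as candidates exist), which is why n_draft_max is the main tuning knob; raising p_min trades shorter drafts for fewer wasted verification positions on the target.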

§Quick start

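use llama_cpp_4::context::params::{LlamaContextParams, LlamaContextType};
use llama_cpp_4::mtp::{MtpSession, MtpSessionConfig};

// Assumes `backend` and an MTP-capable `model` are already loaded.
// Upper bound on draft tokens per step; tune it per model/quant.
let n_draft_max = 3;

// Target context: the ordinary decoding context.
let target = model.new_context(&backend, LlamaContextParams::default())?;

// Draft context: built as an MTP context, with n_rs_seq = max(n_draft_max, 4).
let draft = model.new_context(
    &backend,
    LlamaContextParams::default()
        .with_ctx_type(LlamaContextType::Mtp)
        .with_n_rs_seq(n_draft_max.max(4)),
)?;

// Pair the two contexts in one session; p_min = 0.0 keeps every draft token.
let config = MtpSessionConfig::new(1, n_draft_max).with_p_min(0.0);
let mut session = MtpSession::new_with_config(&target, &draft, config)?;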

§Speculative loop

Each generation step runs the following sequence, starting with a decode on the target context:

// 1. Target prefill or verify decode (you build the batch)
target.decode(&mut batch)?;

// 2. Tell MTP about the batch just decoded on the target
session.process(&batch)?;

// 3. Ask for draft tokens starting from the last accepted token
let drafts = session.draft(0, n_past, last_token)?;

// 4. Verify the drafts on the target (compare logits or sample; this part is your code)
let n_accepted: u16 = /* number of drafts the target accepted */;

// 5. Sync draft recurrent state with what the target accepted
session.accept(0, n_accepted)?;
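Step 4 is left to the caller. One common choice is greedy verification: decode last_token followed by the drafts on the target in one batch, take the target's argmax at each position, and accept drafts up to the first mismatch. The sketch below works on plain token ids so it runs without a model. `verify_greedy` is a hypothetical helper, not part of llama_cpp_4, and it assumes that `n_accepted` counts accepted draft tokens only, excluding the extra token the target always contributes. Sampling-based verification needs a different acceptance rule.

```rust
/// Greedy verification sketch.
/// `target_argmax[i]` is assumed to be the target's most likely token at the
/// position where `drafts[i]` was proposed (after `last_token` + `drafts[..i]`),
/// so decoding `last_token` plus k drafts yields k + 1 entries.
/// Returns (accepted draft count, target token to emit next).
fn verify_greedy(drafts: &[i32], target_argmax: &[i32]) -> (u16, Option<i32>) {
    // Accept drafts while they match the target's own greedy choice.
    let n_accepted = drafts
        .iter()
        .zip(target_argmax)
        .take_while(|(d, t)| d == t)
        .count();
    // The target's prediction at the first mismatch (or after the last
    // accepted draft) is emitted as well, so every step makes progress.
    let next = target_argmax.get(n_accepted).copied();
    (n_accepted as u16, next)
}
```

For example, drafts [5, 6, 7] against target argmax [5, 6, 9, 10] accept two drafts and emit 9 next; the `u16` count is what step 5 passes to `session.accept`.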

Call MtpSession::begin once at the start of each new generation if you want upstream prompt tracking; MTP itself does not require it. When generation finishes, call MtpSession::print_stats to log the draft and accept counters through llama.cpp’s log callback.
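print_stats reports its counters through llama.cpp’s log. If you want similar figures inside your own code, you can track them per verify step. The struct below is a hypothetical helper (its name and fields are illustrative, not from llama_cpp_4), and it assumes each step emits the accepted drafts plus one token from the target, as in greedy verification.

```rust
/// Hypothetical per-session counters for speculative decoding.
#[derive(Default)]
struct SpecStats {
    steps: u64,    // target verify decodes
    drafted: u64,  // draft tokens proposed
    accepted: u64, // draft tokens accepted by the target
}

impl SpecStats {
    /// Record one verify step.
    fn record(&mut self, n_drafted: u64, n_accepted: u64) {
        self.steps += 1;
        self.drafted += n_drafted;
        self.accepted += n_accepted;
    }

    /// Fraction of proposed draft tokens that the target accepted.
    fn acceptance_rate(&self) -> f64 {
        if self.drafted == 0 {
            0.0
        } else {
            self.accepted as f64 / self.drafted as f64
        }
    }

    /// Tokens emitted per target decode: accepted drafts plus the one
    /// token the target contributes each step (assumed, see lead-in).
    fn tokens_per_step(&self) -> f64 {
        if self.steps == 0 {
            0.0
        } else {
            (self.accepted + self.steps) as f64 / self.steps as f64
        }
    }
}
```

For example, three steps that each draft 3 tokens, with 2, 3 and 0 accepted, give an acceptance rate of 5/9 (about 0.56) and 8/3 (about 2.67) tokens per target decode.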

A full runnable implementation is in examples/mtp/.

§Embedding requirements

Method | Typical MTP value | Meaning
MtpSession::need_embd_pre_norm | true | Next-n hidden states (upstream terminology)
MtpSession::need_embd | false | Post-norm / sequence embeddings (not used)

The Rust API keeps the *_pre_norm names, while the upstream C API has used *_nextn since the June 2026 llama.cpp bump. Session initialization configures embedding extraction on both contexts automatically, so a manual call to LlamaContext::set_embeddings_pre_norm is rarely needed.

Structs§

MtpSession
Owned MTP draft session.
MtpSessionConfig
Parameters for MtpSession::new_with_config.

Enums§

MtpSessionError
Errors raised by the MTP draft session.