llama_cpp_4/mtp.rs
1//! Safe wrapper around the C++ MTP draft session.
2//!
3//! [`MtpSession`] pairs a target [`LlamaContext`] with an MTP draft
4//! [`LlamaContext`] (built with
5//! [`crate::context::params::LlamaContextType::Mtp`]) and drives the
6//! multi-token-prediction speculative-decoding loop introduced in upstream
7//! llama.cpp [PR #22673](https://github.com/ggml-org/llama.cpp/pull/22673).
8//!
9//! The draft algorithm lives in upstream's
10//! `common/speculative.cpp` (`common_speculative_impl_draft_mtp`). This module
11//! wraps it through a stable C shim in `llama-cpp-sys-4/mtp_shim/`.
12//!
13//! # Upstream behaviour (llama.cpp #23269+)
14//!
15//! After [MTP clean-up #23269](https://github.com/ggml-org/llama.cpp/pull/23269):
16//!
17//! - Draft sampling uses `top_k = 10` inside upstream (not configurable from Rust).
18//! - [`MtpSessionConfig::p_min`] filters low-confidence draft tokens (default `0.0`).
19//! - Upstream CLI default for `n_max` is `3`; set [`MtpSessionConfig::n_draft_max`]
20//! explicitly — optimal values are model/quant dependent ([`MTP.md`] on GitHub).
21//!
22//! [`MTP.md`]: https://github.com/eugenehp/llama-cpp-rs/blob/main/MTP.md
23//!
24//! # Quick start
25//!
26//! ```ignore
27//! use llama_cpp_4::context::params::{LlamaContextParams, LlamaContextType};
28//! use llama_cpp_4::mtp::{MtpSession, MtpSessionConfig};
29//!
30//! let n_draft_max = 3;
31//!
32//! let target = model.new_context(&backend, LlamaContextParams::default())?;
33//! let draft = model.new_context(
34//! &backend,
35//! LlamaContextParams::default()
36//! .with_ctx_type(LlamaContextType::Mtp)
37//! .with_n_rs_seq(n_draft_max.max(4)),
38//! )?;
39//!
40//! let config = MtpSessionConfig::new(1, n_draft_max).with_p_min(0.0);
41//! let mut session = MtpSession::new_with_config(&target, &draft, config)?;
42//! ```
43//!
44//! # Speculative loop
45//!
46//! For each generation step, after decoding on the **target** context:
47//!
48//! ```ignore
49//! // 1. Target prefill or verify decode (you build the batch)
50//! target.decode(&mut batch)?;
51//!
52//! // 2. Tell MTP about the batch just decoded on the target
53//! session.process(&batch)?;
54//!
55//! // 3. Ask for draft tokens starting from the last accepted token
56//! let drafts = session.draft(0, n_past, last_token)?;
57//!
58//! // 4. Verify drafts on the target (compare logits / sample — your code)
59//! let n_accepted: u16 = /* ... */;
60//!
61//! // 5. Sync draft recurrent state with what the target accepted
62//! session.accept(0, n_accepted)?;
63//! ```
64//!
65//! Call [`MtpSession::begin`] once per fresh generation if you want upstream
66//! prompt tracking (optional for MTP). Call [`MtpSession::print_stats`] when
67//! finished to log draft/accept counters via llama.cpp's log callback.
68//!
69//! A full runnable implementation is in `examples/mtp/`.
70//!
71//! # Embedding requirements
72//!
73//! | Method | MTP typical value | Meaning |
74//! |---|---|---|
75//! | [`MtpSession::need_embd_pre_norm`] | `true` | Next-n hidden states (upstream name) |
76//! | [`MtpSession::need_embd`] | `false` | Post-norm / seq embeddings not used |
77//!
78//! Rust keeps `*_pre_norm` names; upstream C API uses `*_nextn` since the Jun 2026
79//! llama.cpp bump. Session init configures extraction on both contexts automatically;
80//! manual [`LlamaContext::set_embeddings_pre_norm`] is rarely needed.
81
82use std::ptr::NonNull;
83
84use crate::context::LlamaContext;
85use crate::llama_batch::LlamaBatch;
86use crate::token::LlamaToken;
87
88/// Errors raised by the MTP draft session.
89#[derive(Debug, thiserror::Error)]
90pub enum MtpSessionError {
91 /// Returned when `mtp_session_new` fails (typically: model lacks MTP heads,
92 /// or one of the contexts is incompatible).
93 #[error("failed to create MTP draft session — check that ctx_dft was built with LlamaContextType::Mtp and the model has MTP heads")]
94 Init,
95
96 /// `mtp_session_process` returned false.
97 #[error("mtp_session_process failed (see llama.cpp logs)")]
98 Process,
99
100 /// Caller passed a sequence id outside `[0, n_seq)`.
101 #[error("sequence id {seq_id} out of range (n_seq = {n_seq})")]
102 BadSeqId {
103 /// the offending seq id
104 seq_id: i32,
105 /// configured number of sequences
106 n_seq: u32,
107 },
108
109 /// Invalid session configuration (e.g. `n_draft_max <= 0`).
110 #[error("invalid MTP session config: {0}")]
111 InvalidConfig(&'static str),
112}
113
/// Draft parameters consumed by [`MtpSession::new_with_config`].
///
/// The fields map directly to upstream `common_params_speculative_draft`.
///
/// # Examples
///
/// ```ignore
/// // Defaults: n_min = 0, p_min = 0.0 (aligned with upstream #23269+)
/// let cfg = MtpSessionConfig::new(1, 3);
///
/// // Stricter drafts: skip tokens below 10% draft-model probability
/// let cfg = MtpSessionConfig::new(1, 1).with_p_min(0.10);
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MtpSessionConfig {
    /// How many sequences run concurrently (usually `1`).
    pub n_seq: u32,
    /// Upper bound on tokens drafted by one [`MtpSession::draft`] call
    /// (`n_max` upstream).
    pub n_draft_max: i32,
    /// Lower bound on draft tokens to propose (`n_min` upstream, default `0`).
    pub n_min: i32,
    /// Greedy probability floor: drafts below it are dropped
    /// (`p_min` upstream, default `0.0`).
    pub p_min: f32,
}
138
139impl MtpSessionConfig {
140 /// Build config with upstream-aligned defaults for `n_min` (`0`) and `p_min` (`0.0`).
141 ///
142 /// # Examples
143 ///
144 /// ```ignore
145 /// let cfg = MtpSessionConfig::new(1, 3); // one sequence, up to 3 draft tokens
146 /// ```
147 #[must_use]
148 pub fn new(n_seq: u32, n_draft_max: i32) -> Self {
149 Self {
150 n_seq,
151 n_draft_max,
152 n_min: 0,
153 p_min: 0.0,
154 }
155 }
156
157 /// Set minimum draft tokens (`n_min` upstream).
158 #[must_use]
159 pub fn with_n_min(mut self, n_min: i32) -> Self {
160 self.n_min = n_min;
161 self
162 }
163
164 /// Set draft probability floor (`p_min` upstream).
165 ///
166 /// Draft tokens whose greedy probability falls below this value are dropped.
167 /// Upstream default is `0.0` after #23269 (was `0.75` in older builds).
168 ///
169 /// # Examples
170 ///
171 /// ```ignore
172 /// let cfg = MtpSessionConfig::new(1, 1).with_p_min(0.10);
173 /// ```
174 #[must_use]
175 pub fn with_p_min(mut self, p_min: f32) -> Self {
176 self.p_min = p_min;
177 self
178 }
179}
180
181/// Owned MTP draft session.
182///
183/// Drops the underlying `mtp_session *` (and the C++ `common_speculative *`
184/// it holds) when freed.
185///
186/// # Lifetime contract (manual)
187///
188/// The session holds raw pointers to both the target and draft
189/// [`LlamaContext`]s. **The caller must keep both contexts alive (i.e. not
190/// drop them) for as long as the session exists.**
191pub struct MtpSession {
192 raw: NonNull<llama_cpp_sys_4::mtp_session>,
193 config: MtpSessionConfig,
194}
195
196// SAFETY: the underlying C++ session owns its own state and is not tied to
197// any TLS. Concurrent calls from multiple threads are NOT safe.
198unsafe impl Send for MtpSession {}
199
200impl MtpSession {
201 /// Construct an MTP draft session with upstream defaults for `n_min` and
202 /// `p_min`.
203 ///
204 /// Equivalent to `new_with_config(MtpSessionConfig::new(n_seq, n_draft_max))`.
205 ///
206 /// # Examples
207 ///
208 /// ```ignore
209 /// let mut session = MtpSession::new(&target, &draft, 1, 3)?;
210 /// ```
211 ///
212 /// # Errors
213 ///
214 /// Returns [`MtpSessionError::Init`] or [`MtpSessionError::InvalidConfig`].
215 pub fn new(
216 target: &LlamaContext<'_>,
217 draft: &LlamaContext<'_>,
218 n_seq: u32,
219 n_draft_max: i32,
220 ) -> Result<Self, MtpSessionError> {
221 Self::new_with_config(target, draft, MtpSessionConfig::new(n_seq, n_draft_max))
222 }
223
224 /// Construct an MTP draft session with full speculative draft parameters.
225 ///
226 /// `target` must be a [`LlamaContextType::Default`](crate::context::params::LlamaContextType::Default) context.
227 /// `draft` must be a [`LlamaContextType::Mtp`](crate::context::params::LlamaContextType::Mtp) context from the same model,
228 /// with [`LlamaContextParams::with_n_rs_seq`](crate::context::params::LlamaContextParams::with_n_rs_seq)
229 /// `>= config.n_draft_max`.
230 ///
231 /// # Examples
232 ///
233 /// ```ignore
234 /// let config = MtpSessionConfig::new(1, 1)
235 /// .with_p_min(0.0); // match upstream default after #23269
236 /// let session = MtpSession::new_with_config(&target, &draft, config)?;
237 /// ```
238 ///
239 /// # Errors
240 ///
241 /// Returns [`MtpSessionError::Init`] or [`MtpSessionError::InvalidConfig`].
242 pub fn new_with_config(
243 target: &LlamaContext<'_>,
244 draft: &LlamaContext<'_>,
245 config: MtpSessionConfig,
246 ) -> Result<Self, MtpSessionError> {
247 if config.n_seq == 0 {
248 return Err(MtpSessionError::InvalidConfig("n_seq must be > 0"));
249 }
250 if config.n_draft_max <= 0 {
251 return Err(MtpSessionError::InvalidConfig("n_draft_max must be > 0"));
252 }
253
254 let c_config = llama_cpp_sys_4::mtp_session_config {
255 n_seq: config.n_seq,
256 n_draft_max: config.n_draft_max,
257 n_min: config.n_min,
258 p_min: config.p_min,
259 spec_type: llama_cpp_sys_4::MTP_SPEC_TYPE_MTP as i32,
260 };
261
262 let raw = unsafe {
263 llama_cpp_sys_4::mtp_session_new(
264 target.context.as_ptr(),
265 draft.context.as_ptr(),
266 &raw const c_config,
267 )
268 };
269 let raw = NonNull::new(raw).ok_or(MtpSessionError::Init)?;
270 Ok(Self { raw, config })
271 }
272
273 /// Session configuration passed at construction.
274 #[must_use]
275 pub fn config(&self) -> MtpSessionConfig {
276 self.config
277 }
278
279 /// True when the speculative backend needs post-norm embeddings on the
280 /// target context (`llama_set_embeddings`).
281 ///
282 /// MTP returns **false**; use [`Self::need_embd_pre_norm`] for MTP.
283 #[must_use]
284 pub fn need_embd(&self) -> bool {
285 unsafe { llama_cpp_sys_4::mtp_session_need_embd(self.raw.as_ptr()) }
286 }
287
288 /// True when the speculative backend needs pre-norm hidden states on the
289 /// target context (`llama_set_embeddings_pre_norm`).
290 ///
291 /// MTP returns **true**. Upstream configures this on both contexts during
292 /// session init; callers normally do not need to set it manually.
293 #[must_use]
294 pub fn need_embd_pre_norm(&self) -> bool {
295 unsafe { llama_cpp_sys_4::mtp_session_need_embd_pre_norm(self.raw.as_ptr()) }
296 }
297
298 /// Configured maximum number of tokens drafted per [`draft`](Self::draft)
299 /// call.
300 #[must_use]
301 pub fn n_draft_max(&self) -> i32 {
302 self.config.n_draft_max
303 }
304
305 /// Configured minimum draft tokens (`n_min`).
306 #[must_use]
307 pub fn n_min(&self) -> i32 {
308 self.config.n_min
309 }
310
311 /// Configured draft probability floor (`p_min`).
312 #[must_use]
313 pub fn p_min(&self) -> f32 {
314 self.config.p_min
315 }
316
317 /// Configured number of sequences.
318 #[must_use]
319 pub fn n_seq(&self) -> u32 {
320 self.config.n_seq
321 }
322
323 /// Log speculative-decoding statistics (draft/accept counts and timings) via
324 /// llama.cpp `LOG_INF`. Install a log callback with [`crate::log_set`] to
325 /// capture output.
326 ///
327 /// # Examples
328 ///
329 /// ```ignore
330 /// // After your generation loop:
331 /// session.print_stats();
332 /// ```
333 pub fn print_stats(&self) {
334 unsafe { llama_cpp_sys_4::mtp_session_print_stats(self.raw.as_ptr()) }
335 }
336
337 /// Optional: call once at the start of a fresh generation with the
338 /// prompt tokens that were just decoded into the target context.
339 ///
340 /// Upstream uses this for prompt tracking; MTP speculative loops often
341 /// work without it if you call [`Self::process`] after every target decode.
342 ///
343 /// # Examples
344 ///
345 /// ```ignore
346 /// session.begin(0, &prompt_tokens)?;
347 /// ```
348 pub fn begin(&mut self, seq_id: i32, prompt: &[LlamaToken]) -> Result<(), MtpSessionError> {
349 self.check_seq(seq_id)?;
350 unsafe {
351 llama_cpp_sys_4::mtp_session_begin(
352 self.raw.as_ptr(),
353 seq_id,
354 prompt.as_ptr().cast(),
355 prompt.len(),
356 );
357 }
358 Ok(())
359 }
360
361 /// Hand the session a batch that was just decoded on the target context.
362 ///
363 /// Call this after every successful `target.decode(batch)` so upstream can
364 /// sync draft recurrent state with the target KV cache.
365 ///
366 /// # Examples
367 ///
368 /// ```ignore
369 /// target.decode(&mut batch)?;
370 /// session.process(&batch)?;
371 /// ```
372 pub fn process(&mut self, batch: &LlamaBatch) -> Result<(), MtpSessionError> {
373 let ok =
374 unsafe { llama_cpp_sys_4::mtp_session_process(self.raw.as_ptr(), &batch.llama_batch) };
375 if ok {
376 Ok(())
377 } else {
378 Err(MtpSessionError::Process)
379 }
380 }
381
382 /// Generate up to [`n_draft_max`](Self::n_draft_max) speculative tokens.
383 ///
384 /// `n_past` is the number of tokens already in the target KV cache for
385 /// `seq_id`. `id_last` is the last token accepted on the target (usually
386 /// the token you just sampled).
387 ///
388 /// # Examples
389 ///
390 /// ```ignore
391 /// let drafts = session.draft(0, n_past, last_token)?;
392 /// for draft in &drafts {
393 /// // verify each draft against target logits ...
394 /// }
395 /// ```
396 pub fn draft(
397 &mut self,
398 seq_id: i32,
399 n_past: i32,
400 id_last: LlamaToken,
401 ) -> Result<Vec<LlamaToken>, MtpSessionError> {
402 self.check_seq(seq_id)?;
403
404 let cap = self.config.n_draft_max.max(0) as usize;
405 let mut buf: Vec<i32> = vec![0; cap];
406 let mut out_n: i32 = cap as i32;
407
408 unsafe {
409 llama_cpp_sys_4::mtp_session_draft(
410 self.raw.as_ptr(),
411 seq_id,
412 n_past,
413 id_last.0,
414 buf.as_mut_ptr(),
415 &mut out_n,
416 );
417 }
418
419 let n = out_n.max(0) as usize;
420 buf.truncate(n);
421 Ok(buf.into_iter().map(LlamaToken).collect())
422 }
423
424 /// Inform the session how many draft tokens the target verifier accepted.
425 ///
426 /// Pass `0` when every draft was rejected. Upstream rolls back draft
427 /// recurrent state accordingly.
428 ///
429 /// # Examples
430 ///
431 /// ```ignore
432 /// session.accept(0, n_accepted)?;
433 /// ```
434 pub fn accept(&mut self, seq_id: i32, n_accepted: u16) -> Result<(), MtpSessionError> {
435 self.check_seq(seq_id)?;
436 unsafe {
437 llama_cpp_sys_4::mtp_session_accept(self.raw.as_ptr(), seq_id, n_accepted);
438 }
439 Ok(())
440 }
441
442 fn check_seq(&self, seq_id: i32) -> Result<(), MtpSessionError> {
443 if seq_id < 0 || (seq_id as u32) >= self.config.n_seq {
444 return Err(MtpSessionError::BadSeqId {
445 seq_id,
446 n_seq: self.config.n_seq,
447 });
448 }
449 Ok(())
450 }
451}
452
453impl Drop for MtpSession {
454 fn drop(&mut self) {
455 unsafe { llama_cpp_sys_4::mtp_session_free(self.raw.as_ptr()) }
456 }
457}
458
459impl std::fmt::Debug for MtpSession {
460 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
461 f.debug_struct("MtpSession")
462 .field("config", &self.config)
463 .field("need_embd_pre_norm", &self.need_embd_pre_norm())
464 .finish()
465 }
466}