llama_cpp_4/mtp.rs
1//! Safe wrapper around the C++ MTP draft session.
2//!
3//! [`MtpSession`] pairs a target [`LlamaContext`] with an MTP draft
4//! [`LlamaContext`] (built with
5//! [`crate::context::params::LlamaContextType::Mtp`]) and drives the
6//! multi-token-prediction speculative-decoding loop introduced in upstream
7//! llama.cpp [PR #22673](https://github.com/ggml-org/llama.cpp/pull/22673).
8//!
9//! The draft algorithm lives in upstream's
10//! `common/speculative.cpp` (`common_speculative_impl_draft_mtp`). This module
11//! wraps it through a stable C shim in `llama-cpp-sys-4/mtp_shim/`.
12//!
13//! # Upstream behaviour (llama.cpp #23269+)
14//!
15//! After [MTP clean-up #23269](https://github.com/ggml-org/llama.cpp/pull/23269):
16//!
17//! - Draft sampling uses `top_k = 10` inside upstream (not configurable from Rust).
18//! - [`MtpSessionConfig::p_min`] filters low-confidence draft tokens (default `0.0`).
19//! - Upstream CLI default for `n_max` is `3`; set [`MtpSessionConfig::n_draft_max`]
20//! explicitly — optimal values are model/quant dependent ([`MTP.md`] on GitHub).
21//!
22//! [`MTP.md`]: https://github.com/eugenehp/llama-cpp-rs/blob/main/MTP.md
23//!
24//! # Quick start
25//!
26//! ```ignore
27//! use llama_cpp_4::context::params::{LlamaContextParams, LlamaContextType};
28//! use llama_cpp_4::mtp::{MtpSession, MtpSessionConfig};
29//!
30//! let n_draft_max = 3;
31//!
32//! let target = model.new_context(&backend, LlamaContextParams::default())?;
33//! let draft = model.new_context(
34//! &backend,
35//! LlamaContextParams::default()
36//! .with_ctx_type(LlamaContextType::Mtp)
37//! .with_n_rs_seq(n_draft_max.max(4)),
38//! )?;
39//!
40//! let config = MtpSessionConfig::new(1, n_draft_max).with_p_min(0.0);
41//! let mut session = MtpSession::new_with_config(&target, &draft, config)?;
42//! ```
43//!
44//! # Speculative loop
45//!
//! Each generation step runs the following sequence, starting with the decode on the **target** context:
47//!
48//! ```ignore
49//! // 1. Target prefill or verify decode (you build the batch)
50//! target.decode(&mut batch)?;
51//!
52//! // 2. Tell MTP about the batch just decoded on the target
53//! session.process(&batch)?;
54//!
55//! // 3. Ask for draft tokens starting from the last accepted token
56//! let drafts = session.draft(0, n_past, last_token)?;
57//!
58//! // 4. Verify drafts on the target (compare logits / sample — your code)
59//! let n_accepted: u16 = /* ... */;
60//!
61//! // 5. Sync draft recurrent state with what the target accepted
62//! session.accept(0, n_accepted)?;
63//! ```
64//!
65//! Call [`MtpSession::begin`] once per fresh generation if you want upstream
66//! prompt tracking (optional for MTP). Call [`MtpSession::print_stats`] when
67//! finished to log draft/accept counters via llama.cpp's log callback.
68//!
69//! A full runnable implementation is in `examples/mtp/`.
70//!
71//! # Embedding requirements
72//!
73//! | Method | MTP typical value | Meaning |
74//! |---|---|---|
75//! | [`MtpSession::need_embd_pre_norm`] | `true` | Next-n hidden states (upstream name) |
76//! | [`MtpSession::need_embd`] | `false` | Post-norm / seq embeddings not used |
77//!
78//! Rust keeps `*_pre_norm` names; upstream C API uses `*_nextn` since the Jun 2026
79//! llama.cpp bump. Session init configures extraction on both contexts automatically;
80//! manual [`LlamaContext::set_embeddings_pre_norm`] is rarely needed.
81
82use std::ptr::NonNull;
83
84use crate::context::LlamaContext;
85use crate::llama_batch::LlamaBatch;
86use crate::token::LlamaToken;
87
88/// Errors raised by the MTP draft session.
89#[derive(Debug, thiserror::Error)]
90pub enum MtpSessionError {
91 /// Returned when `mtp_session_new` fails (typically: model lacks MTP heads,
92 /// or one of the contexts is incompatible).
93 #[error("failed to create MTP draft session — check that ctx_dft was built with LlamaContextType::Mtp and the model has MTP heads")]
94 Init,
95
96 /// `mtp_session_process` returned false.
97 #[error("mtp_session_process failed (see llama.cpp logs)")]
98 Process,
99
100 /// Caller passed a sequence id outside `[0, n_seq)`.
101 #[error("sequence id {seq_id} out of range (n_seq = {n_seq})")]
102 BadSeqId {
103 /// the offending seq id
104 seq_id: i32,
105 /// configured number of sequences
106 n_seq: u32,
107 },
108
109 /// Invalid session configuration (e.g. `n_draft_max <= 0`).
110 #[error("invalid MTP session config: {0}")]
111 InvalidConfig(&'static str),
112}
113
/// Tunables consumed by [`MtpSession::new_with_config`].
///
/// Every field is forwarded one-to-one into upstream's
/// `common_params_speculative_draft`.
///
/// # Examples
///
/// ```ignore
/// // n_min = 0 and p_min = 0.0, the upstream defaults since #23269
/// let cfg = MtpSessionConfig::new(1, 3);
///
/// // Keep only drafts the MTP head assigns at least 10% probability
/// let cfg = MtpSessionConfig::new(1, 1).with_p_min(0.10);
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MtpSessionConfig {
    /// How many sequences are drafted concurrently; `1` in the common case.
    pub n_seq: u32,
    /// Upper bound on tokens returned by a single [`MtpSession::draft`] call
    /// (upstream `n_max`).
    pub n_draft_max: i32,
    /// Lower bound on the number of proposed draft tokens (upstream `n_min`,
    /// `0` when built via [`MtpSessionConfig::new`]).
    pub n_min: i32,
    /// Greedy-probability cutoff: drafts below it are discarded (upstream
    /// `p_min`, `0.0` when built via [`MtpSessionConfig::new`]).
    pub p_min: f32,
}

impl MtpSessionConfig {
    /// Create a config for `n_seq` sequences drafting up to `n_draft_max`
    /// tokens, with `n_min = 0` and `p_min = 0.0` to match upstream defaults.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let cfg = MtpSessionConfig::new(1, 3); // one sequence, up to 3 draft tokens
    /// ```
    #[must_use]
    pub fn new(n_seq: u32, n_draft_max: i32) -> Self {
        let (n_min, p_min) = (0, 0.0);
        Self {
            n_seq,
            n_draft_max,
            n_min,
            p_min,
        }
    }

    /// Return a copy of this config with `n_min` (upstream minimum draft
    /// count) replaced.
    #[must_use]
    pub fn with_n_min(self, n_min: i32) -> Self {
        Self { n_min, ..self }
    }

    /// Return a copy of this config with the draft probability floor `p_min`
    /// replaced.
    ///
    /// Draft tokens whose greedy probability is below this value get dropped.
    /// Upstream defaults to `0.0` since #23269 (older builds used `0.75`).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let cfg = MtpSessionConfig::new(1, 1).with_p_min(0.10);
    /// ```
    #[must_use]
    pub fn with_p_min(self, p_min: f32) -> Self {
        Self { p_min, ..self }
    }
}
180
181/// Owned MTP draft session.
182///
183/// Drops the underlying `mtp_session *` (and the C++ `common_speculative *`
184/// it holds) when freed.
185///
186/// # Lifetime contract (manual)
187///
188/// The session holds raw pointers to both the target and draft
189/// [`LlamaContext`]s. **The caller must keep both contexts alive (i.e. not
190/// drop them) for as long as the session exists.**
191pub struct MtpSession {
192 raw: NonNull<llama_cpp_sys_4::mtp_session>,
193 config: MtpSessionConfig,
194}
195
196// SAFETY: the underlying C++ session owns its own state and is not tied to
197// any TLS. Concurrent calls from multiple threads are NOT safe.
198unsafe impl Send for MtpSession {}
199
200impl MtpSession {
201 /// Construct an MTP draft session with upstream defaults for `n_min` and
202 /// `p_min`.
203 ///
204 /// Equivalent to `new_with_config(MtpSessionConfig::new(n_seq, n_draft_max))`.
205 ///
206 /// # Examples
207 ///
208 /// ```ignore
209 /// let mut session = MtpSession::new(&target, &draft, 1, 3)?;
210 /// ```
211 ///
212 /// # Errors
213 ///
214 /// Returns [`MtpSessionError::Init`] or [`MtpSessionError::InvalidConfig`].
215 pub fn new(
216 target: &LlamaContext<'_>,
217 draft: &LlamaContext<'_>,
218 n_seq: u32,
219 n_draft_max: i32,
220 ) -> Result<Self, MtpSessionError> {
221 Self::new_with_config(target, draft, MtpSessionConfig::new(n_seq, n_draft_max))
222 }
223
224 /// Construct an MTP draft session with full speculative draft parameters.
225 ///
226 /// `target` must be a [`LlamaContextType::Default`](crate::context::params::LlamaContextType::Default) context.
227 /// `draft` must be a [`LlamaContextType::Mtp`](crate::context::params::LlamaContextType::Mtp) context from the same model,
228 /// with [`LlamaContextParams::with_n_rs_seq`](crate::context::params::LlamaContextParams::with_n_rs_seq)
229 /// `>= config.n_draft_max`.
230 ///
231 /// # Examples
232 ///
233 /// ```ignore
234 /// let config = MtpSessionConfig::new(1, 1)
235 /// .with_p_min(0.0); // match upstream default after #23269
236 /// let session = MtpSession::new_with_config(&target, &draft, config)?;
237 /// ```
238 ///
239 /// # Errors
240 ///
241 /// Returns [`MtpSessionError::Init`] or [`MtpSessionError::InvalidConfig`].
242 pub fn new_with_config(
243 target: &LlamaContext<'_>,
244 draft: &LlamaContext<'_>,
245 config: MtpSessionConfig,
246 ) -> Result<Self, MtpSessionError> {
247 if config.n_seq == 0 {
248 return Err(MtpSessionError::InvalidConfig("n_seq must be > 0"));
249 }
250 if config.n_draft_max <= 0 {
251 return Err(MtpSessionError::InvalidConfig("n_draft_max must be > 0"));
252 }
253
254 let c_config = llama_cpp_sys_4::mtp_session_config {
255 n_seq: config.n_seq,
256 n_draft_max: config.n_draft_max,
257 n_min: config.n_min,
258 p_min: config.p_min,
259 };
260
261 let raw = unsafe {
262 llama_cpp_sys_4::mtp_session_new(
263 target.context.as_ptr(),
264 draft.context.as_ptr(),
265 &raw const c_config,
266 )
267 };
268 let raw = NonNull::new(raw).ok_or(MtpSessionError::Init)?;
269 Ok(Self { raw, config })
270 }
271
272 /// Session configuration passed at construction.
273 #[must_use]
274 pub fn config(&self) -> MtpSessionConfig {
275 self.config
276 }
277
278 /// True when the speculative backend needs post-norm embeddings on the
279 /// target context (`llama_set_embeddings`).
280 ///
281 /// MTP returns **false**; use [`Self::need_embd_pre_norm`] for MTP.
282 #[must_use]
283 pub fn need_embd(&self) -> bool {
284 unsafe { llama_cpp_sys_4::mtp_session_need_embd(self.raw.as_ptr()) }
285 }
286
287 /// True when the speculative backend needs pre-norm hidden states on the
288 /// target context (`llama_set_embeddings_pre_norm`).
289 ///
290 /// MTP returns **true**. Upstream configures this on both contexts during
291 /// session init; callers normally do not need to set it manually.
292 #[must_use]
293 pub fn need_embd_pre_norm(&self) -> bool {
294 unsafe { llama_cpp_sys_4::mtp_session_need_embd_pre_norm(self.raw.as_ptr()) }
295 }
296
297 /// Configured maximum number of tokens drafted per [`draft`](Self::draft)
298 /// call.
299 #[must_use]
300 pub fn n_draft_max(&self) -> i32 {
301 self.config.n_draft_max
302 }
303
304 /// Configured minimum draft tokens (`n_min`).
305 #[must_use]
306 pub fn n_min(&self) -> i32 {
307 self.config.n_min
308 }
309
310 /// Configured draft probability floor (`p_min`).
311 #[must_use]
312 pub fn p_min(&self) -> f32 {
313 self.config.p_min
314 }
315
316 /// Configured number of sequences.
317 #[must_use]
318 pub fn n_seq(&self) -> u32 {
319 self.config.n_seq
320 }
321
322 /// Log speculative-decoding statistics (draft/accept counts and timings) via
323 /// llama.cpp `LOG_INF`. Install a log callback with [`crate::log_set`] to
324 /// capture output.
325 ///
326 /// # Examples
327 ///
328 /// ```ignore
329 /// // After your generation loop:
330 /// session.print_stats();
331 /// ```
332 pub fn print_stats(&self) {
333 unsafe { llama_cpp_sys_4::mtp_session_print_stats(self.raw.as_ptr()) }
334 }
335
336 /// Optional: call once at the start of a fresh generation with the
337 /// prompt tokens that were just decoded into the target context.
338 ///
339 /// Upstream uses this for prompt tracking; MTP speculative loops often
340 /// work without it if you call [`Self::process`] after every target decode.
341 ///
342 /// # Examples
343 ///
344 /// ```ignore
345 /// session.begin(0, &prompt_tokens)?;
346 /// ```
347 pub fn begin(&mut self, seq_id: i32, prompt: &[LlamaToken]) -> Result<(), MtpSessionError> {
348 self.check_seq(seq_id)?;
349 unsafe {
350 llama_cpp_sys_4::mtp_session_begin(
351 self.raw.as_ptr(),
352 seq_id,
353 prompt.as_ptr().cast(),
354 prompt.len(),
355 );
356 }
357 Ok(())
358 }
359
360 /// Hand the session a batch that was just decoded on the target context.
361 ///
362 /// Call this after every successful `target.decode(batch)` so upstream can
363 /// sync draft recurrent state with the target KV cache.
364 ///
365 /// # Examples
366 ///
367 /// ```ignore
368 /// target.decode(&mut batch)?;
369 /// session.process(&batch)?;
370 /// ```
371 pub fn process(&mut self, batch: &LlamaBatch) -> Result<(), MtpSessionError> {
372 let ok =
373 unsafe { llama_cpp_sys_4::mtp_session_process(self.raw.as_ptr(), &batch.llama_batch) };
374 if ok {
375 Ok(())
376 } else {
377 Err(MtpSessionError::Process)
378 }
379 }
380
381 /// Generate up to [`n_draft_max`](Self::n_draft_max) speculative tokens.
382 ///
383 /// `n_past` is the number of tokens already in the target KV cache for
384 /// `seq_id`. `id_last` is the last token accepted on the target (usually
385 /// the token you just sampled).
386 ///
387 /// # Examples
388 ///
389 /// ```ignore
390 /// let drafts = session.draft(0, n_past, last_token)?;
391 /// for draft in &drafts {
392 /// // verify each draft against target logits ...
393 /// }
394 /// ```
395 pub fn draft(
396 &mut self,
397 seq_id: i32,
398 n_past: i32,
399 id_last: LlamaToken,
400 ) -> Result<Vec<LlamaToken>, MtpSessionError> {
401 self.check_seq(seq_id)?;
402
403 let cap = self.config.n_draft_max.max(0) as usize;
404 let mut buf: Vec<i32> = vec![0; cap];
405 let mut out_n: i32 = cap as i32;
406
407 unsafe {
408 llama_cpp_sys_4::mtp_session_draft(
409 self.raw.as_ptr(),
410 seq_id,
411 n_past,
412 id_last.0,
413 buf.as_mut_ptr(),
414 &mut out_n,
415 );
416 }
417
418 let n = out_n.max(0) as usize;
419 buf.truncate(n);
420 Ok(buf.into_iter().map(LlamaToken).collect())
421 }
422
423 /// Inform the session how many draft tokens the target verifier accepted.
424 ///
425 /// Pass `0` when every draft was rejected. Upstream rolls back draft
426 /// recurrent state accordingly.
427 ///
428 /// # Examples
429 ///
430 /// ```ignore
431 /// session.accept(0, n_accepted)?;
432 /// ```
433 pub fn accept(&mut self, seq_id: i32, n_accepted: u16) -> Result<(), MtpSessionError> {
434 self.check_seq(seq_id)?;
435 unsafe {
436 llama_cpp_sys_4::mtp_session_accept(self.raw.as_ptr(), seq_id, n_accepted);
437 }
438 Ok(())
439 }
440
441 fn check_seq(&self, seq_id: i32) -> Result<(), MtpSessionError> {
442 if seq_id < 0 || (seq_id as u32) >= self.config.n_seq {
443 return Err(MtpSessionError::BadSeqId {
444 seq_id,
445 n_seq: self.config.n_seq,
446 });
447 }
448 Ok(())
449 }
450}
451
452impl Drop for MtpSession {
453 fn drop(&mut self) {
454 unsafe { llama_cpp_sys_4::mtp_session_free(self.raw.as_ptr()) }
455 }
456}
457
458impl std::fmt::Debug for MtpSession {
459 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
460 f.debug_struct("MtpSession")
461 .field("config", &self.config)
462 .field("need_embd_pre_norm", &self.need_embd_pre_norm())
463 .finish()
464 }
465}