llama_cpp_4/mtp.rs
1//! Safe wrapper around the C++ MTP draft session.
2//!
3//! [`MtpSession`] pairs a target [`LlamaContext`] with an MTP draft
4//! [`LlamaContext`] (built with
5//! [`crate::context::params::LlamaContextType::Mtp`]) and drives the
6//! multi-token-prediction speculative-decoding loop introduced in upstream
7//! llama.cpp [PR #22673](https://github.com/ggml-org/llama.cpp/pull/22673).
8//!
9//! The draft algorithm lives in upstream's
10//! `common/speculative.cpp` (`common_speculative_impl_draft_mtp`). This module
11//! wraps it through a stable C shim in `llama-cpp-sys-4/mtp_shim/`.
12//!
13//! # Upstream behaviour (llama.cpp #23269+)
14//!
15//! After [MTP clean-up #23269](https://github.com/ggml-org/llama.cpp/pull/23269):
16//!
17//! - Draft sampling uses `top_k = 10` inside upstream (not configurable from Rust).
18//! - [`MtpSessionConfig::p_min`] filters low-confidence draft tokens (default `0.0`).
19//! - Upstream CLI default for `n_max` is `3`; set [`MtpSessionConfig::n_draft_max`]
20//! explicitly — optimal values are model/quant dependent ([`MTP.md`] on GitHub).
21//!
22//! [`MTP.md`]: https://github.com/eugenehp/llama-cpp-rs/blob/main/MTP.md
23//!
24//! # Quick start
25//!
26//! ```ignore
27//! use llama_cpp_4::context::params::{LlamaContextParams, LlamaContextType};
28//! use llama_cpp_4::mtp::{MtpSession, MtpSessionConfig};
29//!
30//! let n_draft_max = 3;
31//!
32//! let target = model.new_context(&backend, LlamaContextParams::default())?;
33//! let draft = model.new_context(
34//! &backend,
35//! LlamaContextParams::default()
36//! .with_ctx_type(LlamaContextType::Mtp)
37//! .with_n_rs_seq(n_draft_max.max(4)),
38//! )?;
39//!
40//! let config = MtpSessionConfig::new(1, n_draft_max).with_p_min(0.0);
41//! let mut session = MtpSession::new_with_config(&target, &draft, config)?;
42//! ```
43//!
44//! # Speculative loop
45//!
//! Each generation step starts with a decode on the **target** context and then drives the session:
47//!
48//! ```ignore
49//! // 1. Target prefill or verify decode (you build the batch)
50//! target.decode(&mut batch)?;
51//!
52//! // 2. Tell MTP about the batch just decoded on the target
53//! session.process(&batch)?;
54//!
55//! // 3. Ask for draft tokens starting from the last accepted token
56//! let drafts = session.draft(0, n_past, last_token)?;
57//!
58//! // 4. Verify drafts on the target (compare logits / sample — your code)
59//! let n_accepted: u16 = /* ... */;
60//!
61//! // 5. Sync draft recurrent state with what the target accepted
62//! session.accept(0, n_accepted)?;
63//! ```
64//!
65//! Call [`MtpSession::begin`] once per fresh generation if you want upstream
66//! prompt tracking (optional for MTP). Call [`MtpSession::print_stats`] when
67//! finished to log draft/accept counters via llama.cpp's log callback.
68//!
69//! A full runnable implementation is in `examples/mtp/`.
70//!
71//! # Embedding requirements
72//!
73//! | Method | MTP typical value | Meaning |
74//! |---|---|---|
75//! | [`MtpSession::need_embd_pre_norm`] | `true` | Next-n hidden states (upstream name) |
76//! | [`MtpSession::need_embd`] | `false` | Post-norm / seq embeddings not used |
77//!
78//! # Multi-head `NextN` (Step3.5+)
79//!
80//! When [`crate::model::LlamaModel::n_layer_nextn`] returns a value greater than `1`, set the
81//! draft context head before each [`MtpSession::draft`] call:
82//!
83//! ```ignore
84//! for head in 0..model.n_layer_nextn() {
85//! draft.set_nextn_layer_offset(head);
86//! let drafts = session.draft(0, n_past, last_token)?;
87//! // verify on target ...
88//! }
89//! draft.set_nextn_layer_offset(0); // restore default
90//! ```
91//!
92
93use std::ptr::NonNull;
94
95use crate::context::LlamaContext;
96use crate::llama_batch::LlamaBatch;
97use crate::token::LlamaToken;
98
99/// Errors raised by the MTP draft session.
100#[derive(Debug, thiserror::Error)]
101pub enum MtpSessionError {
102 /// Returned when `mtp_session_new` fails (typically: model lacks MTP heads,
103 /// or one of the contexts is incompatible).
104 #[error("failed to create MTP draft session — check that ctx_dft was built with LlamaContextType::Mtp and the model has MTP heads")]
105 Init,
106
107 /// `mtp_session_process` returned false.
108 #[error("mtp_session_process failed (see llama.cpp logs)")]
109 Process,
110
111 /// Caller passed a sequence id outside `[0, n_seq)`.
112 #[error("sequence id {seq_id} out of range (n_seq = {n_seq})")]
113 BadSeqId {
114 /// the offending seq id
115 seq_id: i32,
116 /// configured number of sequences
117 n_seq: u32,
118 },
119
120 /// Invalid session configuration (e.g. `n_draft_max <= 0`).
121 #[error("invalid MTP session config: {0}")]
122 InvalidConfig(&'static str),
123}
124
/// Draft parameters consumed by [`MtpSession::new_with_config`].
///
/// The fields mirror upstream `common_params_speculative_draft`.
///
/// # Examples
///
/// ```ignore
/// // n_min starts at 0 and p_min at 0.0 (aligned with upstream #23269+)
/// let cfg = MtpSessionConfig::new(1, 3);
///
/// // Stricter drafts: skip tokens below 10% draft-model probability
/// let cfg = MtpSessionConfig::new(1, 1).with_p_min(0.10);
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MtpSessionConfig {
    /// How many sequences the session serves concurrently (usually `1`).
    pub n_seq: u32,
    /// Upper bound on tokens produced by one [`MtpSession::draft`] call
    /// (`n_max` upstream).
    pub n_draft_max: i32,
    /// Lower bound on proposed draft tokens (`n_min` upstream, default `0`).
    pub n_min: i32,
    /// Greedy probability floor; drafts below it are dropped (`p_min`
    /// upstream, default `0.0`).
    pub p_min: f32,
}

impl MtpSessionConfig {
    /// Create a config from the two mandatory values, leaving `n_min` at `0`
    /// and `p_min` at `0.0` (the upstream-aligned defaults).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let cfg = MtpSessionConfig::new(1, 3); // one sequence, up to 3 draft tokens
    /// ```
    #[must_use]
    pub fn new(n_seq: u32, n_draft_max: i32) -> Self {
        // Starting values for the optional knobs.
        let n_min = 0;
        let p_min = 0.0;
        Self {
            n_seq,
            n_draft_max,
            n_min,
            p_min,
        }
    }

    /// Return a copy of the config with `n_min` (minimum draft tokens,
    /// upstream name `n_min`) replaced.
    #[must_use]
    pub fn with_n_min(self, n_min: i32) -> Self {
        Self { n_min, ..self }
    }

    /// Return a copy of the config with the draft probability floor
    /// (`p_min` upstream) replaced.
    ///
    /// Draft tokens whose greedy probability falls below this value are dropped.
    /// Upstream default is `0.0` after #23269 (was `0.75` in older builds).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let cfg = MtpSessionConfig::new(1, 1).with_p_min(0.10);
    /// ```
    #[must_use]
    pub fn with_p_min(self, p_min: f32) -> Self {
        Self { p_min, ..self }
    }
}
191
192/// Owned MTP draft session.
193///
194/// Drops the underlying `mtp_session *` (and the C++ `common_speculative *`
195/// it holds) when freed.
196///
197/// # Lifetime contract (manual)
198///
199/// The session holds raw pointers to both the target and draft
200/// [`LlamaContext`]s. **The caller must keep both contexts alive (i.e. not
201/// drop them) for as long as the session exists.**
202pub struct MtpSession {
203 raw: NonNull<llama_cpp_sys_4::mtp_session>,
204 config: MtpSessionConfig,
205}
206
207// SAFETY: the underlying C++ session owns its own state and is not tied to
208// any TLS. Concurrent calls from multiple threads are NOT safe.
209unsafe impl Send for MtpSession {}
210
211impl MtpSession {
212 /// Construct an MTP draft session with upstream defaults for `n_min` and
213 /// `p_min`.
214 ///
215 /// Equivalent to `new_with_config(MtpSessionConfig::new(n_seq, n_draft_max))`.
216 ///
217 /// # Examples
218 ///
219 /// ```ignore
220 /// let mut session = MtpSession::new(&target, &draft, 1, 3)?;
221 /// ```
222 ///
223 /// # Errors
224 ///
225 /// Returns [`MtpSessionError::Init`] or [`MtpSessionError::InvalidConfig`].
226 pub fn new(
227 target: &LlamaContext<'_>,
228 draft: &LlamaContext<'_>,
229 n_seq: u32,
230 n_draft_max: i32,
231 ) -> Result<Self, MtpSessionError> {
232 Self::new_with_config(target, draft, MtpSessionConfig::new(n_seq, n_draft_max))
233 }
234
235 /// Construct an MTP draft session with full speculative draft parameters.
236 ///
237 /// `target` must be a [`LlamaContextType::Default`](crate::context::params::LlamaContextType::Default) context.
238 /// `draft` must be a [`LlamaContextType::Mtp`](crate::context::params::LlamaContextType::Mtp) context from the same model,
239 /// with [`LlamaContextParams::with_n_rs_seq`](crate::context::params::LlamaContextParams::with_n_rs_seq)
240 /// `>= config.n_draft_max`.
241 ///
242 /// # Examples
243 ///
244 /// ```ignore
245 /// let config = MtpSessionConfig::new(1, 1)
246 /// .with_p_min(0.0); // match upstream default after #23269
247 /// let session = MtpSession::new_with_config(&target, &draft, config)?;
248 /// ```
249 ///
250 /// # Errors
251 ///
252 /// Returns [`MtpSessionError::Init`] or [`MtpSessionError::InvalidConfig`].
253 pub fn new_with_config(
254 target: &LlamaContext<'_>,
255 draft: &LlamaContext<'_>,
256 config: MtpSessionConfig,
257 ) -> Result<Self, MtpSessionError> {
258 if config.n_seq == 0 {
259 return Err(MtpSessionError::InvalidConfig("n_seq must be > 0"));
260 }
261 if config.n_draft_max <= 0 {
262 return Err(MtpSessionError::InvalidConfig("n_draft_max must be > 0"));
263 }
264
265 let c_config = llama_cpp_sys_4::mtp_session_config {
266 n_seq: config.n_seq,
267 n_draft_max: config.n_draft_max,
268 n_min: config.n_min,
269 p_min: config.p_min,
270 spec_type: llama_cpp_sys_4::MTP_SPEC_TYPE_MTP.cast_signed(),
271 };
272
273 let raw = unsafe {
274 llama_cpp_sys_4::mtp_session_new(
275 target.context.as_ptr(),
276 draft.context.as_ptr(),
277 &raw const c_config,
278 )
279 };
280 let raw = NonNull::new(raw).ok_or(MtpSessionError::Init)?;
281 Ok(Self { raw, config })
282 }
283
284 /// Session configuration passed at construction.
285 #[must_use]
286 pub fn config(&self) -> MtpSessionConfig {
287 self.config
288 }
289
290 /// True when the speculative backend needs post-norm embeddings on the
291 /// target context (`llama_set_embeddings`).
292 ///
293 /// MTP returns **false**; use [`Self::need_embd_pre_norm`] for MTP.
294 #[must_use]
295 pub fn need_embd(&self) -> bool {
296 unsafe { llama_cpp_sys_4::mtp_session_need_embd(self.raw.as_ptr()) }
297 }
298
299 /// True when the speculative backend needs pre-norm hidden states on the
300 /// target context (`llama_set_embeddings_pre_norm`).
301 ///
302 /// MTP returns **true**. Upstream configures this on both contexts during
303 /// session init; callers normally do not need to set it manually.
304 #[must_use]
305 pub fn need_embd_pre_norm(&self) -> bool {
306 unsafe { llama_cpp_sys_4::mtp_session_need_embd_pre_norm(self.raw.as_ptr()) }
307 }
308
309 /// Configured maximum number of tokens drafted per [`draft`](Self::draft)
310 /// call.
311 #[must_use]
312 pub fn n_draft_max(&self) -> i32 {
313 self.config.n_draft_max
314 }
315
316 /// Configured minimum draft tokens (`n_min`).
317 #[must_use]
318 pub fn n_min(&self) -> i32 {
319 self.config.n_min
320 }
321
322 /// Configured draft probability floor (`p_min`).
323 #[must_use]
324 pub fn p_min(&self) -> f32 {
325 self.config.p_min
326 }
327
328 /// Configured number of sequences.
329 #[must_use]
330 pub fn n_seq(&self) -> u32 {
331 self.config.n_seq
332 }
333
334 /// Log speculative-decoding statistics (draft/accept counts and timings) via
335 /// llama.cpp `LOG_INF`. Install a log callback with [`crate::log_set`] to
336 /// capture output.
337 ///
338 /// # Examples
339 ///
340 /// ```ignore
341 /// // After your generation loop:
342 /// session.print_stats();
343 /// ```
344 pub fn print_stats(&self) {
345 unsafe { llama_cpp_sys_4::mtp_session_print_stats(self.raw.as_ptr()) }
346 }
347
348 /// Optional: call once at the start of a fresh generation with the
349 /// prompt tokens that were just decoded into the target context.
350 ///
351 /// Upstream uses this for prompt tracking; MTP speculative loops often
352 /// work without it if you call [`Self::process`] after every target decode.
353 ///
354 /// # Examples
355 ///
356 /// ```ignore
357 /// session.begin(0, &prompt_tokens)?;
358 /// ```
359 ///
360 /// # Errors
361 ///
362 /// Returns [`MtpSessionError::BadSeqId`] if `seq_id` is out of range.
363 pub fn begin(&mut self, seq_id: i32, prompt: &[LlamaToken]) -> Result<(), MtpSessionError> {
364 self.check_seq(seq_id)?;
365 unsafe {
366 llama_cpp_sys_4::mtp_session_begin(
367 self.raw.as_ptr(),
368 seq_id,
369 prompt.as_ptr().cast(),
370 prompt.len(),
371 );
372 }
373 Ok(())
374 }
375
376 /// Hand the session a batch that was just decoded on the target context.
377 ///
378 /// Call this after every successful `target.decode(batch)` so upstream can
379 /// sync draft recurrent state with the target KV cache.
380 ///
381 /// # Examples
382 ///
383 /// ```ignore
384 /// target.decode(&mut batch)?;
385 /// session.process(&batch)?;
386 /// ```
387 ///
388 /// # Errors
389 ///
390 /// Returns [`MtpSessionError::Process`] when upstream rejects the batch.
391 pub fn process(&mut self, batch: &LlamaBatch) -> Result<(), MtpSessionError> {
392 let ok = unsafe {
393 llama_cpp_sys_4::mtp_session_process(self.raw.as_ptr(), &raw const batch.llama_batch)
394 };
395 if ok {
396 Ok(())
397 } else {
398 Err(MtpSessionError::Process)
399 }
400 }
401
402 /// Generate up to [`n_draft_max`](Self::n_draft_max) speculative tokens.
403 ///
404 /// `n_past` is the number of tokens already in the target KV cache for
405 /// `seq_id`. `id_last` is the last token accepted on the target (usually
406 /// the token you just sampled).
407 ///
408 /// # Examples
409 ///
410 /// ```ignore
411 /// let drafts = session.draft(0, n_past, last_token)?;
412 /// for draft in &drafts {
413 /// // verify each draft against target logits ...
414 /// }
415 /// ```
416 ///
417 /// # Errors
418 ///
419 /// Returns [`MtpSessionError::BadSeqId`] if `seq_id` is out of range.
420 pub fn draft(
421 &mut self,
422 seq_id: i32,
423 n_past: i32,
424 id_last: LlamaToken,
425 ) -> Result<Vec<LlamaToken>, MtpSessionError> {
426 self.check_seq(seq_id)?;
427
428 let cap = usize::try_from(self.config.n_draft_max.max(0)).unwrap_or(0);
429 let mut buf: Vec<i32> = vec![0; cap];
430 let mut out_n = i32::try_from(cap).unwrap_or(i32::MAX);
431
432 unsafe {
433 llama_cpp_sys_4::mtp_session_draft(
434 self.raw.as_ptr(),
435 seq_id,
436 n_past,
437 id_last.0,
438 buf.as_mut_ptr(),
439 &raw mut out_n,
440 );
441 }
442
443 let n = usize::try_from(out_n.max(0)).unwrap_or(0);
444 buf.truncate(n);
445 Ok(buf.into_iter().map(LlamaToken).collect())
446 }
447
448 /// Inform the session how many draft tokens the target verifier accepted.
449 ///
450 /// Pass `0` when every draft was rejected. Upstream rolls back draft
451 /// recurrent state accordingly.
452 ///
453 /// # Examples
454 ///
455 /// ```ignore
456 /// session.accept(0, n_accepted)?;
457 /// ```
458 ///
459 /// # Errors
460 ///
461 /// Returns [`MtpSessionError::BadSeqId`] if `seq_id` is out of range.
462 pub fn accept(&mut self, seq_id: i32, n_accepted: u16) -> Result<(), MtpSessionError> {
463 self.check_seq(seq_id)?;
464 unsafe {
465 llama_cpp_sys_4::mtp_session_accept(self.raw.as_ptr(), seq_id, n_accepted);
466 }
467 Ok(())
468 }
469
470 fn check_seq(&self, seq_id: i32) -> Result<(), MtpSessionError> {
471 if seq_id < 0 || seq_id.cast_unsigned() >= self.config.n_seq {
472 return Err(MtpSessionError::BadSeqId {
473 seq_id,
474 n_seq: self.config.n_seq,
475 });
476 }
477 Ok(())
478 }
479}
480
481impl Drop for MtpSession {
482 fn drop(&mut self) {
483 unsafe { llama_cpp_sys_4::mtp_session_free(self.raw.as_ptr()) }
484 }
485}
486
487impl std::fmt::Debug for MtpSession {
488 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
489 f.debug_struct("MtpSession")
490 .field("config", &self.config)
491 .field("need_embd_pre_norm", &self.need_embd_pre_norm())
492 .finish_non_exhaustive()
493 }
494}