Skip to main content

llama_cpp_4/
mtp.rs

1//! Safe wrapper around the C++ MTP draft session.
2//!
3//! [`MtpSession`] pairs a target [`LlamaContext`] with an MTP draft
4//! [`LlamaContext`] (built with
5//! [`crate::context::params::LlamaContextType::Mtp`]) and drives the
6//! multi-token-prediction speculative-decoding loop introduced in upstream
7//! llama.cpp [PR #22673](https://github.com/ggml-org/llama.cpp/pull/22673).
8//!
9//! The draft algorithm lives in upstream's
10//! `common/speculative.cpp` (`common_speculative_impl_draft_mtp`). This module
11//! wraps it through a stable C shim in `llama-cpp-sys-4/mtp_shim/`.
12//!
13//! # Upstream behaviour (llama.cpp #23269+)
14//!
15//! After [MTP clean-up #23269](https://github.com/ggml-org/llama.cpp/pull/23269):
16//!
17//! - Draft sampling uses `top_k = 10` inside upstream (not configurable from Rust).
18//! - [`MtpSessionConfig::p_min`] filters low-confidence draft tokens (default `0.0`).
19//! - Upstream CLI default for `n_max` is `3`; set [`MtpSessionConfig::n_draft_max`]
20//!   explicitly — optimal values are model/quant dependent ([`MTP.md`] on GitHub).
21//!
22//! [`MTP.md`]: https://github.com/eugenehp/llama-cpp-rs/blob/main/MTP.md
23//!
24//! # Quick start
25//!
26//! ```ignore
27//! use llama_cpp_4::context::params::{LlamaContextParams, LlamaContextType};
28//! use llama_cpp_4::mtp::{MtpSession, MtpSessionConfig};
29//!
30//! let n_draft_max = 3;
31//!
32//! let target = model.new_context(&backend, LlamaContextParams::default())?;
33//! let draft = model.new_context(
34//!     &backend,
35//!     LlamaContextParams::default()
36//!         .with_ctx_type(LlamaContextType::Mtp)
37//!         .with_n_rs_seq(n_draft_max.max(4)),
38//! )?;
39//!
40//! let config = MtpSessionConfig::new(1, n_draft_max).with_p_min(0.0);
41//! let mut session = MtpSession::new_with_config(&target, &draft, config)?;
42//! ```
43//!
44//! # Speculative loop
45//!
46//! For each generation step, after decoding on the **target** context:
47//!
48//! ```ignore
49//! // 1. Target prefill or verify decode (you build the batch)
50//! target.decode(&mut batch)?;
51//!
52//! // 2. Tell MTP about the batch just decoded on the target
53//! session.process(&batch)?;
54//!
55//! // 3. Ask for draft tokens starting from the last accepted token
56//! let drafts = session.draft(0, n_past, last_token)?;
57//!
58//! // 4. Verify drafts on the target (compare logits / sample — your code)
59//! let n_accepted: u16 = /* ... */;
60//!
61//! // 5. Sync draft recurrent state with what the target accepted
62//! session.accept(0, n_accepted)?;
63//! ```
64//!
65//! Call [`MtpSession::begin`] once per fresh generation if you want upstream
66//! prompt tracking (optional for MTP). Call [`MtpSession::print_stats`] when
67//! finished to log draft/accept counters via llama.cpp's log callback.
68//!
69//! A full runnable implementation is in `examples/mtp/`.
70//!
71//! # Embedding requirements
72//!
73//! | Method | MTP typical value | Meaning |
74//! |---|---|---|
75//! | [`MtpSession::need_embd_pre_norm`] | `true` | Next-n hidden states (upstream name) |
76//! | [`MtpSession::need_embd`] | `false` | Post-norm / seq embeddings not used |
77//!
78//! # Multi-head `NextN` (Step3.5+)
79//!
80//! When [`crate::model::LlamaModel::n_layer_nextn`] returns a value greater than `1`, set the
81//! draft context head before each [`MtpSession::draft`] call:
82//!
83//! ```ignore
84//! for head in 0..model.n_layer_nextn() {
85//!     draft.set_nextn_layer_offset(head);
86//!     let drafts = session.draft(0, n_past, last_token)?;
87//!     // verify on target ...
88//! }
89//! draft.set_nextn_layer_offset(0); // restore default
90//! ```
91//!
92
93use std::ptr::NonNull;
94
95use crate::context::LlamaContext;
96use crate::llama_batch::LlamaBatch;
97use crate::token::LlamaToken;
98
99/// Errors raised by the MTP draft session.
100#[derive(Debug, thiserror::Error)]
101pub enum MtpSessionError {
102    /// Returned when `mtp_session_new` fails (typically: model lacks MTP heads,
103    /// or one of the contexts is incompatible).
104    #[error("failed to create MTP draft session — check that ctx_dft was built with LlamaContextType::Mtp and the model has MTP heads")]
105    Init,
106
107    /// `mtp_session_process` returned false.
108    #[error("mtp_session_process failed (see llama.cpp logs)")]
109    Process,
110
111    /// Caller passed a sequence id outside `[0, n_seq)`.
112    #[error("sequence id {seq_id} out of range (n_seq = {n_seq})")]
113    BadSeqId {
114        /// the offending seq id
115        seq_id: i32,
116        /// configured number of sequences
117        n_seq: u32,
118    },
119
120    /// Invalid session configuration (e.g. `n_draft_max <= 0`).
121    #[error("invalid MTP session config: {0}")]
122    InvalidConfig(&'static str),
123}
124
/// Speculative-draft parameters consumed by [`MtpSession::new_with_config`].
///
/// Each field mirrors a member of upstream `common_params_speculative_draft`.
///
/// # Examples
///
/// ```ignore
/// // n_min = 0 and p_min = 0.0, matching upstream #23269+
/// let cfg = MtpSessionConfig::new(1, 3);
///
/// // Drop draft tokens whose draft-model probability is below 10%
/// let cfg = MtpSessionConfig::new(1, 1).with_p_min(0.10);
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MtpSessionConfig {
    /// How many sequences run concurrently; `1` in the common case.
    pub n_seq: u32,
    /// Upper bound on tokens returned by one [`MtpSession::draft`] call
    /// (upstream `n_max`).
    pub n_draft_max: i32,
    /// Lower bound on proposed draft tokens (upstream `n_min`);
    /// [`MtpSessionConfig::new`] sets it to `0`.
    pub n_min: i32,
    /// Probability floor below which greedy draft tokens are discarded
    /// (upstream `p_min`); [`MtpSessionConfig::new`] sets it to `0.0`.
    pub p_min: f32,
}

impl MtpSessionConfig {
    /// Creates a config for `n_seq` sequences drafting at most `n_draft_max`
    /// tokens per call, with `n_min = 0` and `p_min = 0.0` (upstream defaults).
    ///
    /// Nothing is validated here; [`MtpSession::new_with_config`] rejects
    /// invalid values such as `n_seq == 0` or `n_draft_max <= 0`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let cfg = MtpSessionConfig::new(1, 3); // one sequence, up to 3 draft tokens
    /// ```
    #[must_use]
    pub fn new(n_seq: u32, n_draft_max: i32) -> Self {
        let n_min = 0;
        let p_min = 0.0;
        Self {
            n_seq,
            n_draft_max,
            n_min,
            p_min,
        }
    }

    /// Returns a copy of this config with `n_min` (upstream minimum draft
    /// length) replaced.
    #[must_use]
    pub fn with_n_min(self, n_min: i32) -> Self {
        Self { n_min, ..self }
    }

    /// Returns a copy of this config with the draft probability floor
    /// (`p_min` upstream) replaced.
    ///
    /// Draft tokens whose greedy probability is below this value are dropped.
    /// Upstream has defaulted to `0.0` since #23269 (older builds used `0.75`).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let cfg = MtpSessionConfig::new(1, 1).with_p_min(0.10);
    /// ```
    #[must_use]
    pub fn with_p_min(self, p_min: f32) -> Self {
        Self { p_min, ..self }
    }
}
191
/// Owned MTP draft session.
///
/// Dropping it frees the underlying `mtp_session *` (and the C++
/// `common_speculative *` it holds).
///
/// # Lifetime contract (manual)
///
/// The session holds raw pointers to both the target and draft
/// [`LlamaContext`]s. **The caller must keep both contexts alive (i.e. not
/// drop them) for as long as the session exists.** The compiler does not
/// enforce this: `MtpSession` has no lifetime parameter tying it to the
/// contexts, so dropping a context first leaves the C++ session referring to
/// freed memory.
pub struct MtpSession {
    // Owned handle returned by `mtp_session_new`; released exactly once in `Drop`.
    raw: NonNull<llama_cpp_sys_4::mtp_session>,
    // Copy of the construction-time configuration, used by the Rust-side
    // getters and the `seq_id` range checks.
    config: MtpSessionConfig,
}

// SAFETY: the underlying C++ session owns its own state and is not tied to
// any TLS, so it may be moved to another thread. Concurrent calls from
// multiple threads are NOT safe. `Sync` is deliberately not implemented (and
// is not auto-derived, because `NonNull` is `!Sync`), so the compiler rejects
// shared access from several threads.
// NOTE(review): the session also reaches into the target and draft contexts
// it was created from. Moving it to another thread is only sound if those
// contexts may also be used from that thread; confirm this against
// `LlamaContext`'s own `Send`/`Sync` guarantees.
unsafe impl Send for MtpSession {}
210
211impl MtpSession {
212    /// Construct an MTP draft session with upstream defaults for `n_min` and
213    /// `p_min`.
214    ///
215    /// Equivalent to `new_with_config(MtpSessionConfig::new(n_seq, n_draft_max))`.
216    ///
217    /// # Examples
218    ///
219    /// ```ignore
220    /// let mut session = MtpSession::new(&target, &draft, 1, 3)?;
221    /// ```
222    ///
223    /// # Errors
224    ///
225    /// Returns [`MtpSessionError::Init`] or [`MtpSessionError::InvalidConfig`].
226    pub fn new(
227        target: &LlamaContext<'_>,
228        draft: &LlamaContext<'_>,
229        n_seq: u32,
230        n_draft_max: i32,
231    ) -> Result<Self, MtpSessionError> {
232        Self::new_with_config(target, draft, MtpSessionConfig::new(n_seq, n_draft_max))
233    }
234
235    /// Construct an MTP draft session with full speculative draft parameters.
236    ///
237    /// `target` must be a [`LlamaContextType::Default`](crate::context::params::LlamaContextType::Default) context.
238    /// `draft` must be a [`LlamaContextType::Mtp`](crate::context::params::LlamaContextType::Mtp) context from the same model,
239    /// with [`LlamaContextParams::with_n_rs_seq`](crate::context::params::LlamaContextParams::with_n_rs_seq)
240    /// `>= config.n_draft_max`.
241    ///
242    /// # Examples
243    ///
244    /// ```ignore
245    /// let config = MtpSessionConfig::new(1, 1)
246    ///     .with_p_min(0.0); // match upstream default after #23269
247    /// let session = MtpSession::new_with_config(&target, &draft, config)?;
248    /// ```
249    ///
250    /// # Errors
251    ///
252    /// Returns [`MtpSessionError::Init`] or [`MtpSessionError::InvalidConfig`].
253    pub fn new_with_config(
254        target: &LlamaContext<'_>,
255        draft: &LlamaContext<'_>,
256        config: MtpSessionConfig,
257    ) -> Result<Self, MtpSessionError> {
258        if config.n_seq == 0 {
259            return Err(MtpSessionError::InvalidConfig("n_seq must be > 0"));
260        }
261        if config.n_draft_max <= 0 {
262            return Err(MtpSessionError::InvalidConfig("n_draft_max must be > 0"));
263        }
264
265        let c_config = llama_cpp_sys_4::mtp_session_config {
266            n_seq: config.n_seq,
267            n_draft_max: config.n_draft_max,
268            n_min: config.n_min,
269            p_min: config.p_min,
270            spec_type: llama_cpp_sys_4::MTP_SPEC_TYPE_MTP.cast_signed(),
271        };
272
273        let raw = unsafe {
274            llama_cpp_sys_4::mtp_session_new(
275                target.context.as_ptr(),
276                draft.context.as_ptr(),
277                &raw const c_config,
278            )
279        };
280        let raw = NonNull::new(raw).ok_or(MtpSessionError::Init)?;
281        Ok(Self { raw, config })
282    }
283
284    /// Session configuration passed at construction.
285    #[must_use]
286    pub fn config(&self) -> MtpSessionConfig {
287        self.config
288    }
289
290    /// True when the speculative backend needs post-norm embeddings on the
291    /// target context (`llama_set_embeddings`).
292    ///
293    /// MTP returns **false**; use [`Self::need_embd_pre_norm`] for MTP.
294    #[must_use]
295    pub fn need_embd(&self) -> bool {
296        unsafe { llama_cpp_sys_4::mtp_session_need_embd(self.raw.as_ptr()) }
297    }
298
299    /// True when the speculative backend needs pre-norm hidden states on the
300    /// target context (`llama_set_embeddings_pre_norm`).
301    ///
302    /// MTP returns **true**. Upstream configures this on both contexts during
303    /// session init; callers normally do not need to set it manually.
304    #[must_use]
305    pub fn need_embd_pre_norm(&self) -> bool {
306        unsafe { llama_cpp_sys_4::mtp_session_need_embd_pre_norm(self.raw.as_ptr()) }
307    }
308
309    /// Configured maximum number of tokens drafted per [`draft`](Self::draft)
310    /// call.
311    #[must_use]
312    pub fn n_draft_max(&self) -> i32 {
313        self.config.n_draft_max
314    }
315
316    /// Configured minimum draft tokens (`n_min`).
317    #[must_use]
318    pub fn n_min(&self) -> i32 {
319        self.config.n_min
320    }
321
322    /// Configured draft probability floor (`p_min`).
323    #[must_use]
324    pub fn p_min(&self) -> f32 {
325        self.config.p_min
326    }
327
328    /// Configured number of sequences.
329    #[must_use]
330    pub fn n_seq(&self) -> u32 {
331        self.config.n_seq
332    }
333
334    /// Log speculative-decoding statistics (draft/accept counts and timings) via
335    /// llama.cpp `LOG_INF`. Install a log callback with [`crate::log_set`] to
336    /// capture output.
337    ///
338    /// # Examples
339    ///
340    /// ```ignore
341    /// // After your generation loop:
342    /// session.print_stats();
343    /// ```
344    pub fn print_stats(&self) {
345        unsafe { llama_cpp_sys_4::mtp_session_print_stats(self.raw.as_ptr()) }
346    }
347
348    /// Optional: call once at the start of a fresh generation with the
349    /// prompt tokens that were just decoded into the target context.
350    ///
351    /// Upstream uses this for prompt tracking; MTP speculative loops often
352    /// work without it if you call [`Self::process`] after every target decode.
353    ///
354    /// # Examples
355    ///
356    /// ```ignore
357    /// session.begin(0, &prompt_tokens)?;
358    /// ```
359    ///
360    /// # Errors
361    ///
362    /// Returns [`MtpSessionError::BadSeqId`] if `seq_id` is out of range.
363    pub fn begin(&mut self, seq_id: i32, prompt: &[LlamaToken]) -> Result<(), MtpSessionError> {
364        self.check_seq(seq_id)?;
365        unsafe {
366            llama_cpp_sys_4::mtp_session_begin(
367                self.raw.as_ptr(),
368                seq_id,
369                prompt.as_ptr().cast(),
370                prompt.len(),
371            );
372        }
373        Ok(())
374    }
375
376    /// Hand the session a batch that was just decoded on the target context.
377    ///
378    /// Call this after every successful `target.decode(batch)` so upstream can
379    /// sync draft recurrent state with the target KV cache.
380    ///
381    /// # Examples
382    ///
383    /// ```ignore
384    /// target.decode(&mut batch)?;
385    /// session.process(&batch)?;
386    /// ```
387    ///
388    /// # Errors
389    ///
390    /// Returns [`MtpSessionError::Process`] when upstream rejects the batch.
391    pub fn process(&mut self, batch: &LlamaBatch) -> Result<(), MtpSessionError> {
392        let ok = unsafe {
393            llama_cpp_sys_4::mtp_session_process(self.raw.as_ptr(), &raw const batch.llama_batch)
394        };
395        if ok {
396            Ok(())
397        } else {
398            Err(MtpSessionError::Process)
399        }
400    }
401
402    /// Generate up to [`n_draft_max`](Self::n_draft_max) speculative tokens.
403    ///
404    /// `n_past` is the number of tokens already in the target KV cache for
405    /// `seq_id`. `id_last` is the last token accepted on the target (usually
406    /// the token you just sampled).
407    ///
408    /// # Examples
409    ///
410    /// ```ignore
411    /// let drafts = session.draft(0, n_past, last_token)?;
412    /// for draft in &drafts {
413    ///     // verify each draft against target logits ...
414    /// }
415    /// ```
416    ///
417    /// # Errors
418    ///
419    /// Returns [`MtpSessionError::BadSeqId`] if `seq_id` is out of range.
420    pub fn draft(
421        &mut self,
422        seq_id: i32,
423        n_past: i32,
424        id_last: LlamaToken,
425    ) -> Result<Vec<LlamaToken>, MtpSessionError> {
426        self.check_seq(seq_id)?;
427
428        let cap = usize::try_from(self.config.n_draft_max.max(0)).unwrap_or(0);
429        let mut buf: Vec<i32> = vec![0; cap];
430        let mut out_n = i32::try_from(cap).unwrap_or(i32::MAX);
431
432        unsafe {
433            llama_cpp_sys_4::mtp_session_draft(
434                self.raw.as_ptr(),
435                seq_id,
436                n_past,
437                id_last.0,
438                buf.as_mut_ptr(),
439                &raw mut out_n,
440            );
441        }
442
443        let n = usize::try_from(out_n.max(0)).unwrap_or(0);
444        buf.truncate(n);
445        Ok(buf.into_iter().map(LlamaToken).collect())
446    }
447
448    /// Inform the session how many draft tokens the target verifier accepted.
449    ///
450    /// Pass `0` when every draft was rejected. Upstream rolls back draft
451    /// recurrent state accordingly.
452    ///
453    /// # Examples
454    ///
455    /// ```ignore
456    /// session.accept(0, n_accepted)?;
457    /// ```
458    ///
459    /// # Errors
460    ///
461    /// Returns [`MtpSessionError::BadSeqId`] if `seq_id` is out of range.
462    pub fn accept(&mut self, seq_id: i32, n_accepted: u16) -> Result<(), MtpSessionError> {
463        self.check_seq(seq_id)?;
464        unsafe {
465            llama_cpp_sys_4::mtp_session_accept(self.raw.as_ptr(), seq_id, n_accepted);
466        }
467        Ok(())
468    }
469
470    fn check_seq(&self, seq_id: i32) -> Result<(), MtpSessionError> {
471        if seq_id < 0 || seq_id.cast_unsigned() >= self.config.n_seq {
472            return Err(MtpSessionError::BadSeqId {
473                seq_id,
474                n_seq: self.config.n_seq,
475            });
476        }
477        Ok(())
478    }
479}
480
481impl Drop for MtpSession {
482    fn drop(&mut self) {
483        unsafe { llama_cpp_sys_4::mtp_session_free(self.raw.as_ptr()) }
484    }
485}
486
487impl std::fmt::Debug for MtpSession {
488    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
489        f.debug_struct("MtpSession")
490            .field("config", &self.config)
491            .field("need_embd_pre_norm", &self.need_embd_pre_norm())
492            .finish_non_exhaustive()
493    }
494}