Skip to main content

llama_cpp_4/
mtp.rs

1//! Safe wrapper around the C++ MTP draft session.
2//!
3//! [`MtpSession`] pairs a target [`LlamaContext`] with an MTP draft
4//! [`LlamaContext`] (built with
5//! [`crate::context::params::LlamaContextType::Mtp`]) and drives the
6//! multi-token-prediction speculative-decoding loop introduced in upstream
7//! llama.cpp [PR #22673](https://github.com/ggml-org/llama.cpp/pull/22673).
8//!
9//! The draft algorithm lives in upstream's
10//! `common/speculative.cpp` (`common_speculative_impl_draft_mtp`). This module
11//! wraps it through a stable C shim in `llama-cpp-sys-4/mtp_shim/`.
12//!
13//! # Upstream behaviour (llama.cpp #23269+)
14//!
15//! After [MTP clean-up #23269](https://github.com/ggml-org/llama.cpp/pull/23269):
16//!
17//! - Draft sampling uses `top_k = 10` inside upstream (not configurable from Rust).
18//! - [`MtpSessionConfig::p_min`] filters low-confidence draft tokens (default `0.0`).
19//! - Upstream CLI default for `n_max` is `3`; set [`MtpSessionConfig::n_draft_max`]
20//!   explicitly — optimal values are model/quant dependent ([`MTP.md`] on GitHub).
21//!
22//! [`MTP.md`]: https://github.com/eugenehp/llama-cpp-rs/blob/main/MTP.md
23//!
24//! # Quick start
25//!
26//! ```ignore
27//! use llama_cpp_4::context::params::{LlamaContextParams, LlamaContextType};
28//! use llama_cpp_4::mtp::{MtpSession, MtpSessionConfig};
29//!
30//! let n_draft_max = 3;
31//!
32//! let target = model.new_context(&backend, LlamaContextParams::default())?;
33//! let draft = model.new_context(
34//!     &backend,
35//!     LlamaContextParams::default()
36//!         .with_ctx_type(LlamaContextType::Mtp)
37//!         .with_n_rs_seq(n_draft_max.max(4)),
38//! )?;
39//!
40//! let config = MtpSessionConfig::new(1, n_draft_max).with_p_min(0.0);
41//! let mut session = MtpSession::new_with_config(&target, &draft, config)?;
42//! ```
43//!
44//! # Speculative loop
45//!
//! Each generation step runs the sequence below. Step 1 decodes on the
//! **target** context; steps 2–5 follow that decode:
47//!
48//! ```ignore
49//! // 1. Target prefill or verify decode (you build the batch)
50//! target.decode(&mut batch)?;
51//!
52//! // 2. Tell MTP about the batch just decoded on the target
53//! session.process(&batch)?;
54//!
55//! // 3. Ask for draft tokens starting from the last accepted token
56//! let drafts = session.draft(0, n_past, last_token)?;
57//!
58//! // 4. Verify drafts on the target (compare logits / sample — your code)
59//! let n_accepted: u16 = /* ... */;
60//!
61//! // 5. Sync draft recurrent state with what the target accepted
62//! session.accept(0, n_accepted)?;
63//! ```
64//!
65//! Call [`MtpSession::begin`] once per fresh generation if you want upstream
66//! prompt tracking (optional for MTP). Call [`MtpSession::print_stats`] when
67//! finished to log draft/accept counters via llama.cpp's log callback.
68//!
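//! A full runnable implementation is in `examples/mtp/`. The session keeps raw
//! pointers to both contexts without borrowing them, so keep the target and draft
//! contexts alive for as long as the session exists.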
70//!
71//! # Embedding requirements
72//!
73//! | Method | MTP typical value | Meaning |
74//! |---|---|---|
75//! | [`MtpSession::need_embd_pre_norm`] | `true` | Next-n hidden states (upstream name) |
76//! | [`MtpSession::need_embd`] | `false` | Post-norm / seq embeddings not used |
77//!
78//! Rust keeps `*_pre_norm` names; upstream C API uses `*_nextn` since the Jun 2026
79//! llama.cpp bump. Session init configures extraction on both contexts automatically;
80//! manual [`LlamaContext::set_embeddings_pre_norm`] is rarely needed.
81
82use std::ptr::NonNull;
83
84use crate::context::LlamaContext;
85use crate::llama_batch::LlamaBatch;
86use crate::token::LlamaToken;
87
88/// Errors raised by the MTP draft session.
89#[derive(Debug, thiserror::Error)]
90pub enum MtpSessionError {
91    /// Returned when `mtp_session_new` fails (typically: model lacks MTP heads,
92    /// or one of the contexts is incompatible).
93    #[error("failed to create MTP draft session — check that ctx_dft was built with LlamaContextType::Mtp and the model has MTP heads")]
94    Init,
95
96    /// `mtp_session_process` returned false.
97    #[error("mtp_session_process failed (see llama.cpp logs)")]
98    Process,
99
100    /// Caller passed a sequence id outside `[0, n_seq)`.
101    #[error("sequence id {seq_id} out of range (n_seq = {n_seq})")]
102    BadSeqId {
103        /// the offending seq id
104        seq_id: i32,
105        /// configured number of sequences
106        n_seq: u32,
107    },
108
109    /// Invalid session configuration (e.g. `n_draft_max <= 0`).
110    #[error("invalid MTP session config: {0}")]
111    InvalidConfig(&'static str),
112}
113
/// Tunables accepted by [`MtpSession::new_with_config`].
///
/// Each field maps directly onto the matching field of upstream
/// `common_params_speculative_draft`.
///
/// # Examples
///
/// ```ignore
/// // n_min = 0 and p_min = 0.0 out of the box (matches upstream #23269+)
/// let cfg = MtpSessionConfig::new(1, 3);
///
/// // Discard draft tokens the draft model gives less than 10% probability
/// let cfg = MtpSessionConfig::new(1, 1).with_p_min(0.10);
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MtpSessionConfig {
    /// How many sequences run concurrently (typically `1`).
    pub n_seq: u32,
    /// Upper bound on tokens produced by one [`MtpSession::draft`] call
    /// (`n_max` upstream).
    pub n_draft_max: i32,
    /// Lower bound on tokens proposed per draft (`n_min` upstream, `0` by default).
    pub n_min: i32,
    /// Greedy-probability cutoff. Draft tokens below it are discarded
    /// (`p_min` upstream, `0.0` by default).
    pub p_min: f32,
}

impl MtpSessionConfig {
    /// Create a config for `n_seq` sequences that drafts at most `n_draft_max`
    /// tokens, with the upstream defaults `n_min = 0` and `p_min = 0.0`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let cfg = MtpSessionConfig::new(1, 3); // one sequence, up to 3 draft tokens
    /// ```
    #[must_use]
    pub fn new(n_seq: u32, n_draft_max: i32) -> Self {
        const DEFAULT_N_MIN: i32 = 0;
        const DEFAULT_P_MIN: f32 = 0.0;

        Self {
            n_seq,
            n_draft_max,
            n_min: DEFAULT_N_MIN,
            p_min: DEFAULT_P_MIN,
        }
    }

    /// Return this config with `n_min` (minimum draft tokens, upstream `n_min`)
    /// replaced. All other fields are kept.
    #[must_use]
    pub fn with_n_min(self, n_min: i32) -> Self {
        Self { n_min, ..self }
    }

    /// Return this config with the draft probability floor (upstream `p_min`)
    /// replaced. All other fields are kept.
    ///
    /// A draft token is dropped when its greedy probability is below `p_min`.
    /// Since #23269 the upstream default is `0.0`; older builds used `0.75`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let cfg = MtpSessionConfig::new(1, 1).with_p_min(0.10);
    /// ```
    #[must_use]
    pub fn with_p_min(self, p_min: f32) -> Self {
        Self { p_min, ..self }
    }
}
180
181/// Owned MTP draft session.
182///
183/// Drops the underlying `mtp_session *` (and the C++ `common_speculative *`
184/// it holds) when freed.
185///
186/// # Lifetime contract (manual)
187///
188/// The session holds raw pointers to both the target and draft
189/// [`LlamaContext`]s. **The caller must keep both contexts alive (i.e. not
190/// drop them) for as long as the session exists.**
191pub struct MtpSession {
192    raw: NonNull<llama_cpp_sys_4::mtp_session>,
193    config: MtpSessionConfig,
194}
195
196// SAFETY: the underlying C++ session owns its own state and is not tied to
197// any TLS. Concurrent calls from multiple threads are NOT safe.
198unsafe impl Send for MtpSession {}
199
200impl MtpSession {
201    /// Construct an MTP draft session with upstream defaults for `n_min` and
202    /// `p_min`.
203    ///
204    /// Equivalent to `new_with_config(MtpSessionConfig::new(n_seq, n_draft_max))`.
205    ///
206    /// # Examples
207    ///
208    /// ```ignore
209    /// let mut session = MtpSession::new(&target, &draft, 1, 3)?;
210    /// ```
211    ///
212    /// # Errors
213    ///
214    /// Returns [`MtpSessionError::Init`] or [`MtpSessionError::InvalidConfig`].
215    pub fn new(
216        target: &LlamaContext<'_>,
217        draft: &LlamaContext<'_>,
218        n_seq: u32,
219        n_draft_max: i32,
220    ) -> Result<Self, MtpSessionError> {
221        Self::new_with_config(target, draft, MtpSessionConfig::new(n_seq, n_draft_max))
222    }
223
224    /// Construct an MTP draft session with full speculative draft parameters.
225    ///
226    /// `target` must be a [`LlamaContextType::Default`](crate::context::params::LlamaContextType::Default) context.
227    /// `draft` must be a [`LlamaContextType::Mtp`](crate::context::params::LlamaContextType::Mtp) context from the same model,
228    /// with [`LlamaContextParams::with_n_rs_seq`](crate::context::params::LlamaContextParams::with_n_rs_seq)
229    /// `>= config.n_draft_max`.
230    ///
231    /// # Examples
232    ///
233    /// ```ignore
234    /// let config = MtpSessionConfig::new(1, 1)
235    ///     .with_p_min(0.0); // match upstream default after #23269
236    /// let session = MtpSession::new_with_config(&target, &draft, config)?;
237    /// ```
238    ///
239    /// # Errors
240    ///
241    /// Returns [`MtpSessionError::Init`] or [`MtpSessionError::InvalidConfig`].
242    pub fn new_with_config(
243        target: &LlamaContext<'_>,
244        draft: &LlamaContext<'_>,
245        config: MtpSessionConfig,
246    ) -> Result<Self, MtpSessionError> {
247        if config.n_seq == 0 {
248            return Err(MtpSessionError::InvalidConfig("n_seq must be > 0"));
249        }
250        if config.n_draft_max <= 0 {
251            return Err(MtpSessionError::InvalidConfig("n_draft_max must be > 0"));
252        }
253
254        let c_config = llama_cpp_sys_4::mtp_session_config {
255            n_seq: config.n_seq,
256            n_draft_max: config.n_draft_max,
257            n_min: config.n_min,
258            p_min: config.p_min,
259        };
260
261        let raw = unsafe {
262            llama_cpp_sys_4::mtp_session_new(
263                target.context.as_ptr(),
264                draft.context.as_ptr(),
265                &raw const c_config,
266            )
267        };
268        let raw = NonNull::new(raw).ok_or(MtpSessionError::Init)?;
269        Ok(Self { raw, config })
270    }
271
272    /// Session configuration passed at construction.
273    #[must_use]
274    pub fn config(&self) -> MtpSessionConfig {
275        self.config
276    }
277
278    /// True when the speculative backend needs post-norm embeddings on the
279    /// target context (`llama_set_embeddings`).
280    ///
281    /// MTP returns **false**; use [`Self::need_embd_pre_norm`] for MTP.
282    #[must_use]
283    pub fn need_embd(&self) -> bool {
284        unsafe { llama_cpp_sys_4::mtp_session_need_embd(self.raw.as_ptr()) }
285    }
286
287    /// True when the speculative backend needs pre-norm hidden states on the
288    /// target context (`llama_set_embeddings_pre_norm`).
289    ///
290    /// MTP returns **true**. Upstream configures this on both contexts during
291    /// session init; callers normally do not need to set it manually.
292    #[must_use]
293    pub fn need_embd_pre_norm(&self) -> bool {
294        unsafe { llama_cpp_sys_4::mtp_session_need_embd_pre_norm(self.raw.as_ptr()) }
295    }
296
297    /// Configured maximum number of tokens drafted per [`draft`](Self::draft)
298    /// call.
299    #[must_use]
300    pub fn n_draft_max(&self) -> i32 {
301        self.config.n_draft_max
302    }
303
304    /// Configured minimum draft tokens (`n_min`).
305    #[must_use]
306    pub fn n_min(&self) -> i32 {
307        self.config.n_min
308    }
309
310    /// Configured draft probability floor (`p_min`).
311    #[must_use]
312    pub fn p_min(&self) -> f32 {
313        self.config.p_min
314    }
315
316    /// Configured number of sequences.
317    #[must_use]
318    pub fn n_seq(&self) -> u32 {
319        self.config.n_seq
320    }
321
322    /// Log speculative-decoding statistics (draft/accept counts and timings) via
323    /// llama.cpp `LOG_INF`. Install a log callback with [`crate::log_set`] to
324    /// capture output.
325    ///
326    /// # Examples
327    ///
328    /// ```ignore
329    /// // After your generation loop:
330    /// session.print_stats();
331    /// ```
332    pub fn print_stats(&self) {
333        unsafe { llama_cpp_sys_4::mtp_session_print_stats(self.raw.as_ptr()) }
334    }
335
336    /// Optional: call once at the start of a fresh generation with the
337    /// prompt tokens that were just decoded into the target context.
338    ///
339    /// Upstream uses this for prompt tracking; MTP speculative loops often
340    /// work without it if you call [`Self::process`] after every target decode.
341    ///
342    /// # Examples
343    ///
344    /// ```ignore
345    /// session.begin(0, &prompt_tokens)?;
346    /// ```
347    pub fn begin(&mut self, seq_id: i32, prompt: &[LlamaToken]) -> Result<(), MtpSessionError> {
348        self.check_seq(seq_id)?;
349        unsafe {
350            llama_cpp_sys_4::mtp_session_begin(
351                self.raw.as_ptr(),
352                seq_id,
353                prompt.as_ptr().cast(),
354                prompt.len(),
355            );
356        }
357        Ok(())
358    }
359
360    /// Hand the session a batch that was just decoded on the target context.
361    ///
362    /// Call this after every successful `target.decode(batch)` so upstream can
363    /// sync draft recurrent state with the target KV cache.
364    ///
365    /// # Examples
366    ///
367    /// ```ignore
368    /// target.decode(&mut batch)?;
369    /// session.process(&batch)?;
370    /// ```
371    pub fn process(&mut self, batch: &LlamaBatch) -> Result<(), MtpSessionError> {
372        let ok =
373            unsafe { llama_cpp_sys_4::mtp_session_process(self.raw.as_ptr(), &batch.llama_batch) };
374        if ok {
375            Ok(())
376        } else {
377            Err(MtpSessionError::Process)
378        }
379    }
380
381    /// Generate up to [`n_draft_max`](Self::n_draft_max) speculative tokens.
382    ///
383    /// `n_past` is the number of tokens already in the target KV cache for
384    /// `seq_id`. `id_last` is the last token accepted on the target (usually
385    /// the token you just sampled).
386    ///
387    /// # Examples
388    ///
389    /// ```ignore
390    /// let drafts = session.draft(0, n_past, last_token)?;
391    /// for draft in &drafts {
392    ///     // verify each draft against target logits ...
393    /// }
394    /// ```
395    pub fn draft(
396        &mut self,
397        seq_id: i32,
398        n_past: i32,
399        id_last: LlamaToken,
400    ) -> Result<Vec<LlamaToken>, MtpSessionError> {
401        self.check_seq(seq_id)?;
402
403        let cap = self.config.n_draft_max.max(0) as usize;
404        let mut buf: Vec<i32> = vec![0; cap];
405        let mut out_n: i32 = cap as i32;
406
407        unsafe {
408            llama_cpp_sys_4::mtp_session_draft(
409                self.raw.as_ptr(),
410                seq_id,
411                n_past,
412                id_last.0,
413                buf.as_mut_ptr(),
414                &mut out_n,
415            );
416        }
417
418        let n = out_n.max(0) as usize;
419        buf.truncate(n);
420        Ok(buf.into_iter().map(LlamaToken).collect())
421    }
422
423    /// Inform the session how many draft tokens the target verifier accepted.
424    ///
425    /// Pass `0` when every draft was rejected. Upstream rolls back draft
426    /// recurrent state accordingly.
427    ///
428    /// # Examples
429    ///
430    /// ```ignore
431    /// session.accept(0, n_accepted)?;
432    /// ```
433    pub fn accept(&mut self, seq_id: i32, n_accepted: u16) -> Result<(), MtpSessionError> {
434        self.check_seq(seq_id)?;
435        unsafe {
436            llama_cpp_sys_4::mtp_session_accept(self.raw.as_ptr(), seq_id, n_accepted);
437        }
438        Ok(())
439    }
440
441    fn check_seq(&self, seq_id: i32) -> Result<(), MtpSessionError> {
442        if seq_id < 0 || (seq_id as u32) >= self.config.n_seq {
443            return Err(MtpSessionError::BadSeqId {
444                seq_id,
445                n_seq: self.config.n_seq,
446            });
447        }
448        Ok(())
449    }
450}
451
452impl Drop for MtpSession {
453    fn drop(&mut self) {
454        unsafe { llama_cpp_sys_4::mtp_session_free(self.raw.as_ptr()) }
455    }
456}
457
458impl std::fmt::Debug for MtpSession {
459    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
460        f.debug_struct("MtpSession")
461            .field("config", &self.config)
462            .field("need_embd_pre_norm", &self.need_embd_pre_norm())
463            .finish()
464    }
465}