Skip to main content

llama_cpp_4/
mtp.rs

1//! Safe wrapper around the C++ MTP draft session.
2//!
3//! [`MtpSession`] pairs a target [`LlamaContext`] with an MTP draft
4//! [`LlamaContext`] (built with
5//! [`crate::context::params::LlamaContextType::Mtp`]) and drives the
6//! multi-token-prediction speculative-decoding loop introduced in upstream
7//! llama.cpp [PR #22673](https://github.com/ggml-org/llama.cpp/pull/22673).
8//!
9//! The actual draft algorithm lives in upstream's
10//! `common/speculative.cpp` (`common_speculative_state_draft_mtp`); this
//! module is a thin, safe Rust wrapper around a small C++ shim in
12//! `llama-cpp-sys-4/mtp_shim/` that re-exposes that C++ class with C linkage.
13//!
14//! # Usage outline
15//!
16//! ```ignore
17//! // Build the target context (default) and the MTP draft context.
18//! let target = model.new_context(&backend, LlamaContextParams::default())?;
19//! let draft  = model.new_context(
20//!     &backend,
21//!     LlamaContextParams::default()
22//!         .with_ctx_type(LlamaContextType::Mtp)
23//!         .with_n_rs_seq(4),
24//! )?;
25//!
26//! let mut sess = MtpSession::new(&target, &draft, 1, 3)?;
27//!
28//! // After every llama_decode on the target context, hand the batch to MTP:
29//! sess.process(&target_batch)?;
30//!
31//! // Then ask for a draft starting from the last sampled token:
32//! let drafts = sess.draft(0, n_past, last_token)?;
33//!
34//! // Verify against target, decide how many to accept, then:
35//! sess.accept(0, n_accepted as u16)?;
36//! ```
37
38use std::ptr::NonNull;
39
40use crate::context::LlamaContext;
41use crate::llama_batch::LlamaBatch;
42use crate::token::LlamaToken;
43
44/// Errors raised by the MTP draft session.
45#[derive(Debug, thiserror::Error)]
46pub enum MtpSessionError {
47    /// Returned when `mtp_session_new` fails (typically: model lacks MTP heads,
48    /// or one of the contexts is incompatible).
49    #[error("failed to create MTP draft session — check that ctx_dft was built with LlamaContextType::Mtp and the model has MTP heads")]
50    Init,
51
52    /// `mtp_session_process` returned false.
53    #[error("mtp_session_process failed (see llama.cpp logs)")]
54    Process,
55
56    /// Caller passed a sequence id outside `[0, n_seq)`.
57    #[error("sequence id {seq_id} out of range (n_seq = {n_seq})")]
58    BadSeqId {
59        /// the offending seq id
60        seq_id: i32,
61        /// configured number of sequences
62        n_seq: u32,
63    },
64}
65
66/// Owned MTP draft session.
67///
68/// Drops the underlying `mtp_session *` (and the C++ `common_speculative *`
69/// it holds) when freed.
70///
71/// # Lifetime contract (manual)
72///
73/// The session holds raw pointers to both the target and draft
74/// [`LlamaContext`]s. **The caller must keep both contexts alive (i.e. not
75/// drop them) for as long as the session exists.** This contract is not
76/// enforced by the borrow checker — the session does not hold Rust borrows of
77/// the contexts, because both contexts must remain individually mutable
78/// (you'll be calling `target.decode(...)` while the session exists, and the
79/// session also mutates the draft context internally).
80///
81/// Dropping a context that the session still references is undefined
82/// behaviour at the C++ level (use-after-free inside `common_speculative_*`).
83pub struct MtpSession {
84    raw: NonNull<llama_cpp_sys_4::mtp_session>,
85    n_seq: u32,
86    n_draft_max: i32,
87}
88
// SAFETY: the underlying C++ session owns its own state and is not tied to
// any TLS. Concurrent calls from multiple threads are NOT safe (it mutates
// internal buffers without locking) — that's modelled by `&mut self` on the
// mutating methods.
//
// NOTE(review): the session also stores raw pointers to the target and draft
// contexts, and their owners can keep using those contexts. If the session is
// moved to another thread while either context is still used on the original
// thread, the C++ code could touch one context from two threads at once.
// Confirm this is acceptable, or document that callers must not do so.
unsafe impl Send for MtpSession {}
94
95impl MtpSession {
96    /// Construct an MTP draft session.
97    ///
98    /// `target` must be a `LlamaContextType::Default` context.
99    /// `draft` must be a `LlamaContextType::Mtp` context built from the same
100    /// model and configured with `with_n_rs_seq(>= n_draft_max)`.
101    ///
102    /// `n_seq` is the number of concurrent sequences (1 for a single
103    /// conversation). `n_draft_max` caps the number of tokens drafted per
104    /// round (commonly 3 for Qwen3.6 MTP).
105    ///
106    /// # Errors
107    ///
108    /// Returns [`MtpSessionError::Init`] if upstream rejects the
109    /// configuration (e.g. the model has no MTP heads).
110    pub fn new(
111        target: &LlamaContext<'_>,
112        draft: &LlamaContext<'_>,
113        n_seq: u32,
114        n_draft_max: i32,
115    ) -> Result<Self, MtpSessionError> {
116        let raw = unsafe {
117            llama_cpp_sys_4::mtp_session_new(
118                target.context.as_ptr(),
119                draft.context.as_ptr(),
120                n_seq,
121                n_draft_max,
122            )
123        };
124        let raw = NonNull::new(raw).ok_or(MtpSessionError::Init)?;
125        Ok(Self {
126            raw,
127            n_seq,
128            n_draft_max,
129        })
130    }
131
132    /// True if MTP requires embeddings to be extractable from the target
133    /// context. For MTP this is always true — exposed for symmetry with
134    /// upstream's `common_speculative_need_embd`.
135    #[must_use]
136    pub fn need_embd(&self) -> bool {
137        unsafe { llama_cpp_sys_4::mtp_session_need_embd(self.raw.as_ptr()) }
138    }
139
140    /// Configured maximum number of tokens drafted per [`draft`](Self::draft)
141    /// call.
142    #[must_use]
143    pub fn n_draft_max(&self) -> i32 {
144        self.n_draft_max
145    }
146
147    /// Configured number of sequences.
148    #[must_use]
149    pub fn n_seq(&self) -> u32 {
150        self.n_seq
151    }
152
153    /// Optional: call once at the start of a fresh generation with the
154    /// prompt tokens that were just decoded into the target context.
155    pub fn begin(&mut self, seq_id: i32, prompt: &[LlamaToken]) -> Result<(), MtpSessionError> {
156        self.check_seq(seq_id)?;
157        unsafe {
158            llama_cpp_sys_4::mtp_session_begin(
159                self.raw.as_ptr(),
160                seq_id,
161                prompt.as_ptr().cast(),
162                prompt.len(),
163            );
164        }
165        Ok(())
166    }
167
168    /// Hand the session a batch that was just decoded on the target context.
169    /// MTP needs to see every target batch (prompt prefill + each
170    /// verification step) to keep its per-sequence pre-norm-embedding
171    /// carryover in sync.
172    ///
173    /// # Errors
174    ///
175    /// Returns [`MtpSessionError::Process`] if upstream rejects the batch
176    /// (most often: the batch carries `embd` directly rather than tokens).
177    pub fn process(&mut self, batch: &LlamaBatch) -> Result<(), MtpSessionError> {
178        let ok = unsafe {
179            llama_cpp_sys_4::mtp_session_process(self.raw.as_ptr(), &batch.llama_batch)
180        };
181        if ok {
182            Ok(())
183        } else {
184            Err(MtpSessionError::Process)
185        }
186    }
187
188    /// Generate up to [`n_draft_max`](Self::n_draft_max) speculative tokens
189    /// for sequence `seq_id`, starting from `id_last` at position `n_past`.
190    ///
191    /// Returns an owned `Vec<LlamaToken>` of length `<= n_draft_max`.
192    ///
193    /// # Errors
194    ///
195    /// Returns [`MtpSessionError::BadSeqId`] if `seq_id` is outside the
196    /// configured `n_seq` range.
197    pub fn draft(
198        &mut self,
199        seq_id: i32,
200        n_past: i32,
201        id_last: LlamaToken,
202    ) -> Result<Vec<LlamaToken>, MtpSessionError> {
203        self.check_seq(seq_id)?;
204
205        let cap = self.n_draft_max.max(0) as usize;
206        let mut buf: Vec<i32> = vec![0; cap];
207        let mut out_n: i32 = cap as i32;
208
209        unsafe {
210            llama_cpp_sys_4::mtp_session_draft(
211                self.raw.as_ptr(),
212                seq_id,
213                n_past,
214                id_last.0,
215                buf.as_mut_ptr(),
216                &mut out_n,
217            );
218        }
219
220        let n = out_n.max(0) as usize;
221        buf.truncate(n);
222        Ok(buf.into_iter().map(LlamaToken).collect())
223    }
224
225    /// Inform the session that `n_accepted` tokens from the last draft were
226    /// accepted by the target verifier. This is required after every
227    /// [`draft`](Self::draft) call to keep the draft context's recurrent
228    /// state consistent.
229    pub fn accept(&mut self, seq_id: i32, n_accepted: u16) -> Result<(), MtpSessionError> {
230        self.check_seq(seq_id)?;
231        unsafe {
232            llama_cpp_sys_4::mtp_session_accept(self.raw.as_ptr(), seq_id, n_accepted);
233        }
234        Ok(())
235    }
236
237    fn check_seq(&self, seq_id: i32) -> Result<(), MtpSessionError> {
238        if seq_id < 0 || (seq_id as u32) >= self.n_seq {
239            return Err(MtpSessionError::BadSeqId {
240                seq_id,
241                n_seq: self.n_seq,
242            });
243        }
244        Ok(())
245    }
246}
247
248impl Drop for MtpSession {
249    fn drop(&mut self) {
250        unsafe { llama_cpp_sys_4::mtp_session_free(self.raw.as_ptr()) }
251    }
252}
253
254impl std::fmt::Debug for MtpSession {
255    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256        f.debug_struct("MtpSession")
257            .field("n_seq", &self.n_seq)
258            .field("n_draft_max", &self.n_draft_max)
259            .finish()
260    }
261}