llama_cpp_4/mtp.rs
1//! Safe wrapper around the C++ MTP draft session.
2//!
3//! [`MtpSession`] pairs a target [`LlamaContext`] with an MTP draft
4//! [`LlamaContext`] (built with
5//! [`crate::context::params::LlamaContextType::Mtp`]) and drives the
6//! multi-token-prediction speculative-decoding loop introduced in upstream
7//! llama.cpp [PR #22673](https://github.com/ggml-org/llama.cpp/pull/22673).
8//!
//! The actual draft algorithm lives in upstream's
//! `common/speculative.cpp` (`common_speculative_state_draft_mtp`); this
//! module is a thin, safe Rust wrapper around a small C++ shim in
//! `llama-cpp-sys-4/mtp_shim/` that re-exposes that C++ class with C linkage.
13//!
14//! # Usage outline
15//!
16//! ```ignore
17//! // Build the target context (default) and the MTP draft context.
18//! let target = model.new_context(&backend, LlamaContextParams::default())?;
19//! let draft = model.new_context(
20//! &backend,
21//! LlamaContextParams::default()
22//! .with_ctx_type(LlamaContextType::Mtp)
23//! .with_n_rs_seq(4),
24//! )?;
25//!
26//! let mut sess = MtpSession::new(&target, &draft, 1, 3)?;
27//!
28//! // After every llama_decode on the target context, hand the batch to MTP:
29//! sess.process(&target_batch)?;
30//!
31//! // Then ask for a draft starting from the last sampled token:
32//! let drafts = sess.draft(0, n_past, last_token)?;
33//!
34//! // Verify against target, decide how many to accept, then:
35//! sess.accept(0, n_accepted as u16)?;
36//! ```
37
38use std::ptr::NonNull;
39
40use crate::context::LlamaContext;
41use crate::llama_batch::LlamaBatch;
42use crate::token::LlamaToken;
43
44/// Errors raised by the MTP draft session.
45#[derive(Debug, thiserror::Error)]
46pub enum MtpSessionError {
47 /// Returned when `mtp_session_new` fails (typically: model lacks MTP heads,
48 /// or one of the contexts is incompatible).
49 #[error("failed to create MTP draft session — check that ctx_dft was built with LlamaContextType::Mtp and the model has MTP heads")]
50 Init,
51
52 /// `mtp_session_process` returned false.
53 #[error("mtp_session_process failed (see llama.cpp logs)")]
54 Process,
55
56 /// Caller passed a sequence id outside `[0, n_seq)`.
57 #[error("sequence id {seq_id} out of range (n_seq = {n_seq})")]
58 BadSeqId {
59 /// the offending seq id
60 seq_id: i32,
61 /// configured number of sequences
62 n_seq: u32,
63 },
64}
65
66/// Owned MTP draft session.
67///
68/// Drops the underlying `mtp_session *` (and the C++ `common_speculative *`
69/// it holds) when freed.
70///
71/// # Lifetime contract (manual)
72///
73/// The session holds raw pointers to both the target and draft
74/// [`LlamaContext`]s. **The caller must keep both contexts alive (i.e. not
75/// drop them) for as long as the session exists.** This contract is not
76/// enforced by the borrow checker — the session does not hold Rust borrows of
77/// the contexts, because both contexts must remain individually mutable
78/// (you'll be calling `target.decode(...)` while the session exists, and the
79/// session also mutates the draft context internally).
80///
81/// Dropping a context that the session still references is undefined
82/// behaviour at the C++ level (use-after-free inside `common_speculative_*`).
83pub struct MtpSession {
84 raw: NonNull<llama_cpp_sys_4::mtp_session>,
85 n_seq: u32,
86 n_draft_max: i32,
87}
88
89// SAFETY: the underlying C++ session owns its own state and is not tied to
90// any TLS. Concurrent calls from multiple threads are NOT safe (it mutates
91// internal buffers without locking) — that's modelled by `&mut self` on the
92// mutating methods.
93unsafe impl Send for MtpSession {}
94
95impl MtpSession {
96 /// Construct an MTP draft session.
97 ///
98 /// `target` must be a `LlamaContextType::Default` context.
99 /// `draft` must be a `LlamaContextType::Mtp` context built from the same
100 /// model and configured with `with_n_rs_seq(>= n_draft_max)`.
101 ///
102 /// `n_seq` is the number of concurrent sequences (1 for a single
103 /// conversation). `n_draft_max` caps the number of tokens drafted per
104 /// round (commonly 3 for Qwen3.6 MTP).
105 ///
106 /// # Errors
107 ///
108 /// Returns [`MtpSessionError::Init`] if upstream rejects the
109 /// configuration (e.g. the model has no MTP heads).
110 pub fn new(
111 target: &LlamaContext<'_>,
112 draft: &LlamaContext<'_>,
113 n_seq: u32,
114 n_draft_max: i32,
115 ) -> Result<Self, MtpSessionError> {
116 let raw = unsafe {
117 llama_cpp_sys_4::mtp_session_new(
118 target.context.as_ptr(),
119 draft.context.as_ptr(),
120 n_seq,
121 n_draft_max,
122 )
123 };
124 let raw = NonNull::new(raw).ok_or(MtpSessionError::Init)?;
125 Ok(Self {
126 raw,
127 n_seq,
128 n_draft_max,
129 })
130 }
131
132 /// True if MTP requires embeddings to be extractable from the target
133 /// context. For MTP this is always true — exposed for symmetry with
134 /// upstream's `common_speculative_need_embd`.
135 #[must_use]
136 pub fn need_embd(&self) -> bool {
137 unsafe { llama_cpp_sys_4::mtp_session_need_embd(self.raw.as_ptr()) }
138 }
139
140 /// Configured maximum number of tokens drafted per [`draft`](Self::draft)
141 /// call.
142 #[must_use]
143 pub fn n_draft_max(&self) -> i32 {
144 self.n_draft_max
145 }
146
147 /// Configured number of sequences.
148 #[must_use]
149 pub fn n_seq(&self) -> u32 {
150 self.n_seq
151 }
152
153 /// Optional: call once at the start of a fresh generation with the
154 /// prompt tokens that were just decoded into the target context.
155 pub fn begin(&mut self, seq_id: i32, prompt: &[LlamaToken]) -> Result<(), MtpSessionError> {
156 self.check_seq(seq_id)?;
157 unsafe {
158 llama_cpp_sys_4::mtp_session_begin(
159 self.raw.as_ptr(),
160 seq_id,
161 prompt.as_ptr().cast(),
162 prompt.len(),
163 );
164 }
165 Ok(())
166 }
167
168 /// Hand the session a batch that was just decoded on the target context.
169 /// MTP needs to see every target batch (prompt prefill + each
170 /// verification step) to keep its per-sequence pre-norm-embedding
171 /// carryover in sync.
172 ///
173 /// # Errors
174 ///
175 /// Returns [`MtpSessionError::Process`] if upstream rejects the batch
176 /// (most often: the batch carries `embd` directly rather than tokens).
177 pub fn process(&mut self, batch: &LlamaBatch) -> Result<(), MtpSessionError> {
178 let ok =
179 unsafe { llama_cpp_sys_4::mtp_session_process(self.raw.as_ptr(), &batch.llama_batch) };
180 if ok {
181 Ok(())
182 } else {
183 Err(MtpSessionError::Process)
184 }
185 }
186
187 /// Generate up to [`n_draft_max`](Self::n_draft_max) speculative tokens
188 /// for sequence `seq_id`, starting from `id_last` at position `n_past`.
189 ///
190 /// Returns an owned `Vec<LlamaToken>` of length `<= n_draft_max`.
191 ///
192 /// # Errors
193 ///
194 /// Returns [`MtpSessionError::BadSeqId`] if `seq_id` is outside the
195 /// configured `n_seq` range.
196 pub fn draft(
197 &mut self,
198 seq_id: i32,
199 n_past: i32,
200 id_last: LlamaToken,
201 ) -> Result<Vec<LlamaToken>, MtpSessionError> {
202 self.check_seq(seq_id)?;
203
204 let cap = self.n_draft_max.max(0) as usize;
205 let mut buf: Vec<i32> = vec![0; cap];
206 let mut out_n: i32 = cap as i32;
207
208 unsafe {
209 llama_cpp_sys_4::mtp_session_draft(
210 self.raw.as_ptr(),
211 seq_id,
212 n_past,
213 id_last.0,
214 buf.as_mut_ptr(),
215 &mut out_n,
216 );
217 }
218
219 let n = out_n.max(0) as usize;
220 buf.truncate(n);
221 Ok(buf.into_iter().map(LlamaToken).collect())
222 }
223
224 /// Inform the session that `n_accepted` tokens from the last draft were
225 /// accepted by the target verifier. This is required after every
226 /// [`draft`](Self::draft) call to keep the draft context's recurrent
227 /// state consistent.
228 pub fn accept(&mut self, seq_id: i32, n_accepted: u16) -> Result<(), MtpSessionError> {
229 self.check_seq(seq_id)?;
230 unsafe {
231 llama_cpp_sys_4::mtp_session_accept(self.raw.as_ptr(), seq_id, n_accepted);
232 }
233 Ok(())
234 }
235
236 fn check_seq(&self, seq_id: i32) -> Result<(), MtpSessionError> {
237 if seq_id < 0 || (seq_id as u32) >= self.n_seq {
238 return Err(MtpSessionError::BadSeqId {
239 seq_id,
240 n_seq: self.n_seq,
241 });
242 }
243 Ok(())
244 }
245}
246
247impl Drop for MtpSession {
248 fn drop(&mut self) {
249 unsafe { llama_cpp_sys_4::mtp_session_free(self.raw.as_ptr()) }
250 }
251}
252
253impl std::fmt::Debug for MtpSession {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 f.debug_struct("MtpSession")
256 .field("n_seq", &self.n_seq)
257 .field("n_draft_max", &self.n_draft_max)
258 .finish()
259 }
260}