llama_cpp_4/eagle.rs
1//! Safe wrapper around the C++ EAGLE-3 draft session.
2//!
3//! [`Eagle3Session`] drives **EAGLE-3** speculative decoding
4//! (`COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3` in upstream llama.cpp). EAGLE-3
5//! pairs a target model with a small, separately-trained **EAGLE-3 draft
6//! model** that predicts the next tokens from hidden states extracted out of
7//! the target model.
8//!
9//! The draft algorithm lives in upstream's `common/speculative.cpp`
10//! (`common_speculative_impl_draft_eagle3`). This module wraps it through the
11//! same stable C shim used for MTP (`llama-cpp-sys-4/mtp_shim/`); the two
12//! techniques share an identical session lifecycle and differ only in how the
13//! draft context is built.
14//!
15//! # EAGLE-3 vs MTP
16//!
17//! | | EAGLE-3 ([`Eagle3Session`]) | MTP ([`crate::mtp::MtpSession`]) |
18//! |---|---|---|
19//! | Draft weights | a **separate** EAGLE-3 draft model | the **same** model as the target |
20//! | Draft context type | [`LlamaContextType::Default`](crate::context::params::LlamaContextType::Default) | [`LlamaContextType::Mtp`](crate::context::params::LlamaContextType::Mtp) |
21//! | Requirement | draft model must expose 3 target-extract layers | target model must have MTP heads |
22//!
23//! # Setup
24//!
25//! ```ignore
26//! use llama_cpp_4::context::params::LlamaContextParams;
27//! use llama_cpp_4::eagle::{Eagle3Session, Eagle3SessionConfig};
28//!
29//! let n_draft_max = 3;
30//!
31//! // Target: the main model, a normal (default) context.
32//! let target = main_model.new_context(&backend, LlamaContextParams::default())?;
33//!
34//! // Draft: a SEPARATE EAGLE-3 draft model, also a default context.
35//! let draft = eagle3_model.new_context(&backend, LlamaContextParams::default())?;
36//!
37//! let config = Eagle3SessionConfig::new(1, n_draft_max);
38//! let mut session = Eagle3Session::new_with_config(&target, &draft, config)?;
39//! ```
40//!
41//! # Speculative loop
42//!
43//! Identical in shape to MTP: after each decode on the **target** context call
44//! [`process`](Eagle3Session::process), then [`draft`](Eagle3Session::draft)
45//! to get candidate tokens, verify them on the target, and report how many
46//! were accepted with [`accept`](Eagle3Session::accept).
47//!
48//! ```ignore
49//! target.decode(&mut batch)?;
50//! session.process(&batch)?;
51//! let drafts = session.draft(0, n_past, last_token)?;
52//! // verify `drafts` against the target, count acceptances ...
53//! session.accept(0, n_accepted)?;
54//! ```
55//!
56//! # Hidden-state extraction
57//!
58//! EAGLE-3 needs the target model to expose internal hidden states. The
59//! session configures the required extraction on both contexts at construction
60//! time; [`need_embd`](Eagle3Session::need_embd) and
61//! [`need_embd_pre_norm`](Eagle3Session::need_embd_pre_norm) report which kind
62//! the active backend requested (rarely needed by callers).
63
64use std::ptr::NonNull;
65
66use crate::context::LlamaContext;
67use crate::llama_batch::LlamaBatch;
68use crate::token::LlamaToken;
69
70/// Errors raised by the EAGLE-3 draft session.
71#[derive(Debug, thiserror::Error)]
72pub enum Eagle3SessionError {
73 /// Returned when session init fails. The most common cause is that `draft`
74 /// was not built from a valid EAGLE-3 draft model (upstream expects a draft
75 /// model exposing exactly 3 target-extract layers), or that one of the
76 /// contexts is incompatible.
77 #[error("failed to create EAGLE-3 draft session — check that `draft` is a context over a valid EAGLE-3 draft model (3 extract layers) built from the same target")]
78 Init,
79
80 /// `process` returned false on the underlying speculative context.
81 #[error("EAGLE-3 process failed (see llama.cpp logs)")]
82 Process,
83
84 /// Caller passed a sequence id outside `[0, n_seq)`.
85 #[error("sequence id {seq_id} out of range (n_seq = {n_seq})")]
86 BadSeqId {
87 /// the offending seq id
88 seq_id: i32,
89 /// configured number of sequences
90 n_seq: u32,
91 },
92
93 /// Invalid session configuration (e.g. `n_draft_max <= 0`).
94 #[error("invalid EAGLE-3 session config: {0}")]
95 InvalidConfig(&'static str),
96}
97
/// Tuning knobs consumed by [`Eagle3Session::new_with_config`].
///
/// Every field corresponds to one in upstream
/// `common_params_speculative_draft`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Eagle3SessionConfig {
    /// How many sequences run concurrently (usually `1`).
    pub n_seq: u32,
    /// Upper bound on tokens produced by a single [`Eagle3Session::draft`]
    /// call (`n_max` upstream).
    pub n_draft_max: i32,
    /// Lower bound on draft tokens to propose (`n_min` upstream, default `0`).
    pub n_min: i32,
    /// Greedy probability floor; drafts below it are dropped (`p_min`
    /// upstream, default `0.0`).
    pub p_min: f32,
}

impl Eagle3SessionConfig {
    /// Create a config from the two mandatory values, leaving `n_min` at `0`
    /// and `p_min` at `0.0` (the upstream-aligned defaults).
    #[must_use]
    pub fn new(n_seq: u32, n_draft_max: i32) -> Self {
        Self {
            p_min: 0.0,
            n_min: 0,
            n_draft_max,
            n_seq,
        }
    }

    /// Return a copy of this config with the minimum draft length (`n_min`
    /// upstream) replaced.
    #[must_use]
    pub fn with_n_min(self, n_min: i32) -> Self {
        Self { n_min, ..self }
    }

    /// Return a copy of this config with the draft probability floor (`p_min`
    /// upstream) replaced.
    ///
    /// Draft tokens whose greedy probability falls below this value are
    /// dropped.
    #[must_use]
    pub fn with_p_min(self, p_min: f32) -> Self {
        Self { p_min, ..self }
    }
}
142
143/// Owned EAGLE-3 draft session.
144///
145/// Drops the underlying speculative context when freed.
146///
147/// # Lifetime contract (manual)
148///
149/// The session holds raw pointers to both the target and draft
150/// [`LlamaContext`]s. **The caller must keep both contexts alive (i.e. not
151/// drop them) for as long as the session exists.**
152pub struct Eagle3Session {
153 raw: NonNull<llama_cpp_sys_4::mtp_session>,
154 config: Eagle3SessionConfig,
155}
156
157// SAFETY: the underlying C++ session owns its own state and is not tied to any
158// TLS. Concurrent calls from multiple threads are NOT safe.
159unsafe impl Send for Eagle3Session {}
160
161impl Eagle3Session {
162 /// Construct an EAGLE-3 draft session with upstream defaults for `n_min`
163 /// and `p_min`.
164 ///
165 /// Equivalent to `new_with_config(target, draft, Eagle3SessionConfig::new(n_seq, n_draft_max))`.
166 ///
167 /// # Errors
168 ///
169 /// Returns [`Eagle3SessionError::Init`] or [`Eagle3SessionError::InvalidConfig`].
170 pub fn new(
171 target: &LlamaContext<'_>,
172 draft: &LlamaContext<'_>,
173 n_seq: u32,
174 n_draft_max: i32,
175 ) -> Result<Self, Eagle3SessionError> {
176 Self::new_with_config(target, draft, Eagle3SessionConfig::new(n_seq, n_draft_max))
177 }
178
179 /// Construct an EAGLE-3 draft session with full speculative draft
180 /// parameters.
181 ///
182 /// `target` must be a
183 /// [`LlamaContextType::Default`](crate::context::params::LlamaContextType::Default)
184 /// context over the main model. `draft` must be a `Default` context over a
185 /// **separate EAGLE-3 draft model** trained against that target.
186 ///
187 /// # Errors
188 ///
189 /// Returns [`Eagle3SessionError::Init`] (e.g. the draft model is not a
190 /// valid EAGLE-3 model) or [`Eagle3SessionError::InvalidConfig`].
191 pub fn new_with_config(
192 target: &LlamaContext<'_>,
193 draft: &LlamaContext<'_>,
194 config: Eagle3SessionConfig,
195 ) -> Result<Self, Eagle3SessionError> {
196 if config.n_seq == 0 {
197 return Err(Eagle3SessionError::InvalidConfig("n_seq must be > 0"));
198 }
199 if config.n_draft_max <= 0 {
200 return Err(Eagle3SessionError::InvalidConfig("n_draft_max must be > 0"));
201 }
202
203 let c_config = llama_cpp_sys_4::mtp_session_config {
204 n_seq: config.n_seq,
205 n_draft_max: config.n_draft_max,
206 n_min: config.n_min,
207 p_min: config.p_min,
208 spec_type: llama_cpp_sys_4::MTP_SPEC_TYPE_EAGLE3 as i32,
209 };
210
211 let raw = unsafe {
212 llama_cpp_sys_4::mtp_session_new(
213 target.context.as_ptr(),
214 draft.context.as_ptr(),
215 &raw const c_config,
216 )
217 };
218 let raw = NonNull::new(raw).ok_or(Eagle3SessionError::Init)?;
219 Ok(Self { raw, config })
220 }
221
222 /// Session configuration passed at construction.
223 #[must_use]
224 pub fn config(&self) -> Eagle3SessionConfig {
225 self.config
226 }
227
228 /// True when the speculative backend needs post-norm embeddings on the
229 /// target context (`llama_set_embeddings`).
230 #[must_use]
231 pub fn need_embd(&self) -> bool {
232 unsafe { llama_cpp_sys_4::mtp_session_need_embd(self.raw.as_ptr()) }
233 }
234
235 /// True when the speculative backend needs pre-norm hidden states on the
236 /// target context (`llama_set_embeddings_pre_norm`).
237 ///
238 /// Configured automatically during session init; callers normally do not
239 /// need to set it manually.
240 #[must_use]
241 pub fn need_embd_pre_norm(&self) -> bool {
242 unsafe { llama_cpp_sys_4::mtp_session_need_embd_pre_norm(self.raw.as_ptr()) }
243 }
244
245 /// Configured maximum number of tokens drafted per [`draft`](Self::draft) call.
246 #[must_use]
247 pub fn n_draft_max(&self) -> i32 {
248 self.config.n_draft_max
249 }
250
251 /// Configured minimum draft tokens (`n_min`).
252 #[must_use]
253 pub fn n_min(&self) -> i32 {
254 self.config.n_min
255 }
256
257 /// Configured draft probability floor (`p_min`).
258 #[must_use]
259 pub fn p_min(&self) -> f32 {
260 self.config.p_min
261 }
262
263 /// Configured number of sequences.
264 #[must_use]
265 pub fn n_seq(&self) -> u32 {
266 self.config.n_seq
267 }
268
269 /// Log speculative-decoding statistics (draft/accept counts and timings)
270 /// via llama.cpp `LOG_INF`. Install a log callback with [`crate::log_set`]
271 /// to capture output.
272 pub fn print_stats(&self) {
273 unsafe { llama_cpp_sys_4::mtp_session_print_stats(self.raw.as_ptr()) }
274 }
275
276 /// Optional: call once at the start of a fresh generation with the prompt
277 /// tokens that were just decoded into the target context.
278 ///
279 /// # Errors
280 ///
281 /// Returns [`Eagle3SessionError::BadSeqId`] if `seq_id` is out of range.
282 pub fn begin(&mut self, seq_id: i32, prompt: &[LlamaToken]) -> Result<(), Eagle3SessionError> {
283 self.check_seq(seq_id)?;
284 unsafe {
285 llama_cpp_sys_4::mtp_session_begin(
286 self.raw.as_ptr(),
287 seq_id,
288 prompt.as_ptr().cast(),
289 prompt.len(),
290 );
291 }
292 Ok(())
293 }
294
295 /// Hand the session a batch that was just decoded on the target context.
296 ///
297 /// Call this after every successful `target.decode(batch)` so upstream can
298 /// harvest the target hidden states EAGLE-3 drafts from.
299 ///
300 /// # Errors
301 ///
302 /// Returns [`Eagle3SessionError::Process`] if the underlying call fails.
303 pub fn process(&mut self, batch: &LlamaBatch) -> Result<(), Eagle3SessionError> {
304 let ok =
305 unsafe { llama_cpp_sys_4::mtp_session_process(self.raw.as_ptr(), &batch.llama_batch) };
306 if ok {
307 Ok(())
308 } else {
309 Err(Eagle3SessionError::Process)
310 }
311 }
312
313 /// Generate up to [`n_draft_max`](Self::n_draft_max) speculative tokens.
314 ///
315 /// `n_past` is the number of tokens already in the target KV cache for
316 /// `seq_id`. `id_last` is the last token accepted on the target (usually
317 /// the token you just sampled).
318 ///
319 /// # Errors
320 ///
321 /// Returns [`Eagle3SessionError::BadSeqId`] if `seq_id` is out of range.
322 pub fn draft(
323 &mut self,
324 seq_id: i32,
325 n_past: i32,
326 id_last: LlamaToken,
327 ) -> Result<Vec<LlamaToken>, Eagle3SessionError> {
328 self.check_seq(seq_id)?;
329
330 let cap = self.config.n_draft_max.max(0) as usize;
331 let mut buf: Vec<i32> = vec![0; cap];
332 let mut out_n: i32 = cap as i32;
333
334 unsafe {
335 llama_cpp_sys_4::mtp_session_draft(
336 self.raw.as_ptr(),
337 seq_id,
338 n_past,
339 id_last.0,
340 buf.as_mut_ptr(),
341 &mut out_n,
342 );
343 }
344
345 let n = out_n.max(0) as usize;
346 buf.truncate(n);
347 Ok(buf.into_iter().map(LlamaToken).collect())
348 }
349
350 /// Inform the session how many draft tokens the target verifier accepted.
351 ///
352 /// Pass `0` when every draft was rejected.
353 ///
354 /// # Errors
355 ///
356 /// Returns [`Eagle3SessionError::BadSeqId`] if `seq_id` is out of range.
357 pub fn accept(&mut self, seq_id: i32, n_accepted: u16) -> Result<(), Eagle3SessionError> {
358 self.check_seq(seq_id)?;
359 unsafe {
360 llama_cpp_sys_4::mtp_session_accept(self.raw.as_ptr(), seq_id, n_accepted);
361 }
362 Ok(())
363 }
364
365 fn check_seq(&self, seq_id: i32) -> Result<(), Eagle3SessionError> {
366 if seq_id < 0 || (seq_id as u32) >= self.config.n_seq {
367 return Err(Eagle3SessionError::BadSeqId {
368 seq_id,
369 n_seq: self.config.n_seq,
370 });
371 }
372 Ok(())
373 }
374}
375
376impl Drop for Eagle3Session {
377 fn drop(&mut self) {
378 unsafe { llama_cpp_sys_4::mtp_session_free(self.raw.as_ptr()) }
379 }
380}
381
382impl std::fmt::Debug for Eagle3Session {
383 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384 f.debug_struct("Eagle3Session")
385 .field("config", &self.config)
386 .finish()
387 }
388}