Skip to main content

llama_cpp_4/
eagle.rs

//! Safe wrapper around the C++ EAGLE-3 draft session.
//!
//! [`Eagle3Session`] drives **EAGLE-3** speculative decoding
//! (`COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3` in upstream llama.cpp). EAGLE-3
//! pairs a target model with a small, separately-trained **EAGLE-3 draft
//! model** that predicts the next tokens from hidden states extracted out of
//! the target model.
//!
//! The draft algorithm lives in upstream's `common/speculative.cpp`
//! (`common_speculative_impl_draft_eagle3`). This module wraps it through the
//! same stable C shim used for MTP (`llama-cpp-sys-4/mtp_shim/`); the two
//! techniques share an identical session lifecycle and differ only in how the
//! draft context is built.
//!
//! # EAGLE-3 vs MTP
//!
//! | | EAGLE-3 ([`Eagle3Session`]) | MTP ([`crate::mtp::MtpSession`]) |
//! |---|---|---|
//! | Draft weights | a **separate** EAGLE-3 draft model | the **same** model as the target |
//! | Draft context type | [`LlamaContextType::Default`](crate::context::params::LlamaContextType::Default) | [`LlamaContextType::Mtp`](crate::context::params::LlamaContextType::Mtp) |
//! | Requirement | draft model must expose 3 target-extract layers | target model must have MTP heads |
//!
//! # Setup
//!
//! ```ignore
//! use llama_cpp_4::context::params::LlamaContextParams;
//! use llama_cpp_4::eagle::{Eagle3Session, Eagle3SessionConfig};
//!
//! let n_draft_max = 3;
//!
//! // Target: the main model, a normal (default) context.
//! let target = main_model.new_context(&backend, LlamaContextParams::default())?;
//!
//! // Draft: a SEPARATE EAGLE-3 draft model, also a default context.
//! let draft = eagle3_model.new_context(&backend, LlamaContextParams::default())?;
//!
//! let config = Eagle3SessionConfig::new(1, n_draft_max);
//! let mut session = Eagle3Session::new_with_config(&target, &draft, config)?;
//! ```
//!
//! # Speculative loop
//!
//! Identical in shape to MTP: after each decode on the **target** context call
//! [`process`](Eagle3Session::process), then [`draft`](Eagle3Session::draft)
//! to get candidate tokens, verify them on the target, and report how many
//! were accepted with [`accept`](Eagle3Session::accept).
//!
//! ```ignore
//! target.decode(&mut batch)?;
//! session.process(&batch)?;
//! let drafts = session.draft(0, n_past, last_token)?;
//! // verify `drafts` against the target, count acceptances ...
//! session.accept(0, n_accepted)?;
//! ```
//!
//! # Hidden-state extraction
//!
//! EAGLE-3 needs the target model to expose internal hidden states. The
//! session configures the required extraction on both contexts at construction
//! time; [`need_embd`](Eagle3Session::need_embd) and
//! [`need_embd_pre_norm`](Eagle3Session::need_embd_pre_norm) report which kind
//! the active backend requested (rarely needed by callers).
63
64use std::ptr::NonNull;
65
66use crate::context::LlamaContext;
67use crate::llama_batch::LlamaBatch;
68use crate::token::LlamaToken;
69
70/// Errors raised by the EAGLE-3 draft session.
71#[derive(Debug, thiserror::Error)]
72pub enum Eagle3SessionError {
73    /// Returned when session init fails. The most common cause is that `draft`
74    /// was not built from a valid EAGLE-3 draft model (upstream expects a draft
75    /// model exposing exactly 3 target-extract layers), or that one of the
76    /// contexts is incompatible.
77    #[error("failed to create EAGLE-3 draft session — check that `draft` is a context over a valid EAGLE-3 draft model (3 extract layers) built from the same target")]
78    Init,
79
80    /// `process` returned false on the underlying speculative context.
81    #[error("EAGLE-3 process failed (see llama.cpp logs)")]
82    Process,
83
84    /// Caller passed a sequence id outside `[0, n_seq)`.
85    #[error("sequence id {seq_id} out of range (n_seq = {n_seq})")]
86    BadSeqId {
87        /// the offending seq id
88        seq_id: i32,
89        /// configured number of sequences
90        n_seq: u32,
91    },
92
93    /// Invalid session configuration (e.g. `n_draft_max <= 0`).
94    #[error("invalid EAGLE-3 session config: {0}")]
95    InvalidConfig(&'static str),
96}
97
/// Tunables accepted by [`Eagle3Session::new_with_config`].
///
/// Maps directly onto upstream `common_params_speculative_draft`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Eagle3SessionConfig {
    /// How many sequences the session serves concurrently (usually `1`).
    pub n_seq: u32,
    /// Upper bound on tokens drafted by one [`Eagle3Session::draft`] call
    /// (`n_max` upstream).
    pub n_draft_max: i32,
    /// Lower bound on draft tokens to propose (`n_min` upstream, default `0`).
    pub n_min: i32,
    /// Greedy probability floor (`p_min` upstream, default `0.0`); drafts
    /// below it are dropped.
    pub p_min: f32,
}

impl Eagle3SessionConfig {
    /// Create a config from the two mandatory values, leaving `n_min` at `0`
    /// and `p_min` at `0.0` (the upstream-aligned defaults).
    #[must_use]
    pub fn new(n_seq: u32, n_draft_max: i32) -> Self {
        let (n_min, p_min) = (0, 0.0);
        Self {
            n_seq,
            n_draft_max,
            n_min,
            p_min,
        }
    }

    /// Return a copy of this config with the minimum draft length replaced
    /// (`n_min` upstream).
    #[must_use]
    pub fn with_n_min(self, n_min: i32) -> Self {
        Self { n_min, ..self }
    }

    /// Return a copy of this config with the draft probability floor replaced
    /// (`p_min` upstream).
    ///
    /// Draft tokens whose greedy probability falls below this value are dropped.
    #[must_use]
    pub fn with_p_min(self, p_min: f32) -> Self {
        Self { p_min, ..self }
    }
}
142
143/// Owned EAGLE-3 draft session.
144///
145/// Drops the underlying speculative context when freed.
146///
147/// # Lifetime contract (manual)
148///
149/// The session holds raw pointers to both the target and draft
150/// [`LlamaContext`]s. **The caller must keep both contexts alive (i.e. not
151/// drop them) for as long as the session exists.**
152pub struct Eagle3Session {
153    raw: NonNull<llama_cpp_sys_4::mtp_session>,
154    config: Eagle3SessionConfig,
155}
156
157// SAFETY: the underlying C++ session owns its own state and is not tied to any
158// TLS. Concurrent calls from multiple threads are NOT safe.
159unsafe impl Send for Eagle3Session {}
160
161impl Eagle3Session {
162    /// Construct an EAGLE-3 draft session with upstream defaults for `n_min`
163    /// and `p_min`.
164    ///
165    /// Equivalent to `new_with_config(target, draft, Eagle3SessionConfig::new(n_seq, n_draft_max))`.
166    ///
167    /// # Errors
168    ///
169    /// Returns [`Eagle3SessionError::Init`] or [`Eagle3SessionError::InvalidConfig`].
170    pub fn new(
171        target: &LlamaContext<'_>,
172        draft: &LlamaContext<'_>,
173        n_seq: u32,
174        n_draft_max: i32,
175    ) -> Result<Self, Eagle3SessionError> {
176        Self::new_with_config(target, draft, Eagle3SessionConfig::new(n_seq, n_draft_max))
177    }
178
179    /// Construct an EAGLE-3 draft session with full speculative draft
180    /// parameters.
181    ///
182    /// `target` must be a
183    /// [`LlamaContextType::Default`](crate::context::params::LlamaContextType::Default)
184    /// context over the main model. `draft` must be a `Default` context over a
185    /// **separate EAGLE-3 draft model** trained against that target.
186    ///
187    /// # Errors
188    ///
189    /// Returns [`Eagle3SessionError::Init`] (e.g. the draft model is not a
190    /// valid EAGLE-3 model) or [`Eagle3SessionError::InvalidConfig`].
191    pub fn new_with_config(
192        target: &LlamaContext<'_>,
193        draft: &LlamaContext<'_>,
194        config: Eagle3SessionConfig,
195    ) -> Result<Self, Eagle3SessionError> {
196        if config.n_seq == 0 {
197            return Err(Eagle3SessionError::InvalidConfig("n_seq must be > 0"));
198        }
199        if config.n_draft_max <= 0 {
200            return Err(Eagle3SessionError::InvalidConfig("n_draft_max must be > 0"));
201        }
202
203        let c_config = llama_cpp_sys_4::mtp_session_config {
204            n_seq: config.n_seq,
205            n_draft_max: config.n_draft_max,
206            n_min: config.n_min,
207            p_min: config.p_min,
208            spec_type: llama_cpp_sys_4::MTP_SPEC_TYPE_EAGLE3 as i32,
209        };
210
211        let raw = unsafe {
212            llama_cpp_sys_4::mtp_session_new(
213                target.context.as_ptr(),
214                draft.context.as_ptr(),
215                &raw const c_config,
216            )
217        };
218        let raw = NonNull::new(raw).ok_or(Eagle3SessionError::Init)?;
219        Ok(Self { raw, config })
220    }
221
222    /// Session configuration passed at construction.
223    #[must_use]
224    pub fn config(&self) -> Eagle3SessionConfig {
225        self.config
226    }
227
228    /// True when the speculative backend needs post-norm embeddings on the
229    /// target context (`llama_set_embeddings`).
230    #[must_use]
231    pub fn need_embd(&self) -> bool {
232        unsafe { llama_cpp_sys_4::mtp_session_need_embd(self.raw.as_ptr()) }
233    }
234
235    /// True when the speculative backend needs pre-norm hidden states on the
236    /// target context (`llama_set_embeddings_pre_norm`).
237    ///
238    /// Configured automatically during session init; callers normally do not
239    /// need to set it manually.
240    #[must_use]
241    pub fn need_embd_pre_norm(&self) -> bool {
242        unsafe { llama_cpp_sys_4::mtp_session_need_embd_pre_norm(self.raw.as_ptr()) }
243    }
244
245    /// Configured maximum number of tokens drafted per [`draft`](Self::draft) call.
246    #[must_use]
247    pub fn n_draft_max(&self) -> i32 {
248        self.config.n_draft_max
249    }
250
251    /// Configured minimum draft tokens (`n_min`).
252    #[must_use]
253    pub fn n_min(&self) -> i32 {
254        self.config.n_min
255    }
256
257    /// Configured draft probability floor (`p_min`).
258    #[must_use]
259    pub fn p_min(&self) -> f32 {
260        self.config.p_min
261    }
262
263    /// Configured number of sequences.
264    #[must_use]
265    pub fn n_seq(&self) -> u32 {
266        self.config.n_seq
267    }
268
269    /// Log speculative-decoding statistics (draft/accept counts and timings)
270    /// via llama.cpp `LOG_INF`. Install a log callback with [`crate::log_set`]
271    /// to capture output.
272    pub fn print_stats(&self) {
273        unsafe { llama_cpp_sys_4::mtp_session_print_stats(self.raw.as_ptr()) }
274    }
275
276    /// Optional: call once at the start of a fresh generation with the prompt
277    /// tokens that were just decoded into the target context.
278    ///
279    /// # Errors
280    ///
281    /// Returns [`Eagle3SessionError::BadSeqId`] if `seq_id` is out of range.
282    pub fn begin(&mut self, seq_id: i32, prompt: &[LlamaToken]) -> Result<(), Eagle3SessionError> {
283        self.check_seq(seq_id)?;
284        unsafe {
285            llama_cpp_sys_4::mtp_session_begin(
286                self.raw.as_ptr(),
287                seq_id,
288                prompt.as_ptr().cast(),
289                prompt.len(),
290            );
291        }
292        Ok(())
293    }
294
295    /// Hand the session a batch that was just decoded on the target context.
296    ///
297    /// Call this after every successful `target.decode(batch)` so upstream can
298    /// harvest the target hidden states EAGLE-3 drafts from.
299    ///
300    /// # Errors
301    ///
302    /// Returns [`Eagle3SessionError::Process`] if the underlying call fails.
303    pub fn process(&mut self, batch: &LlamaBatch) -> Result<(), Eagle3SessionError> {
304        let ok =
305            unsafe { llama_cpp_sys_4::mtp_session_process(self.raw.as_ptr(), &batch.llama_batch) };
306        if ok {
307            Ok(())
308        } else {
309            Err(Eagle3SessionError::Process)
310        }
311    }
312
313    /// Generate up to [`n_draft_max`](Self::n_draft_max) speculative tokens.
314    ///
315    /// `n_past` is the number of tokens already in the target KV cache for
316    /// `seq_id`. `id_last` is the last token accepted on the target (usually
317    /// the token you just sampled).
318    ///
319    /// # Errors
320    ///
321    /// Returns [`Eagle3SessionError::BadSeqId`] if `seq_id` is out of range.
322    pub fn draft(
323        &mut self,
324        seq_id: i32,
325        n_past: i32,
326        id_last: LlamaToken,
327    ) -> Result<Vec<LlamaToken>, Eagle3SessionError> {
328        self.check_seq(seq_id)?;
329
330        let cap = self.config.n_draft_max.max(0) as usize;
331        let mut buf: Vec<i32> = vec![0; cap];
332        let mut out_n: i32 = cap as i32;
333
334        unsafe {
335            llama_cpp_sys_4::mtp_session_draft(
336                self.raw.as_ptr(),
337                seq_id,
338                n_past,
339                id_last.0,
340                buf.as_mut_ptr(),
341                &mut out_n,
342            );
343        }
344
345        let n = out_n.max(0) as usize;
346        buf.truncate(n);
347        Ok(buf.into_iter().map(LlamaToken).collect())
348    }
349
350    /// Inform the session how many draft tokens the target verifier accepted.
351    ///
352    /// Pass `0` when every draft was rejected.
353    ///
354    /// # Errors
355    ///
356    /// Returns [`Eagle3SessionError::BadSeqId`] if `seq_id` is out of range.
357    pub fn accept(&mut self, seq_id: i32, n_accepted: u16) -> Result<(), Eagle3SessionError> {
358        self.check_seq(seq_id)?;
359        unsafe {
360            llama_cpp_sys_4::mtp_session_accept(self.raw.as_ptr(), seq_id, n_accepted);
361        }
362        Ok(())
363    }
364
365    fn check_seq(&self, seq_id: i32) -> Result<(), Eagle3SessionError> {
366        if seq_id < 0 || (seq_id as u32) >= self.config.n_seq {
367            return Err(Eagle3SessionError::BadSeqId {
368                seq_id,
369                n_seq: self.config.n_seq,
370            });
371        }
372        Ok(())
373    }
374}
375
376impl Drop for Eagle3Session {
377    fn drop(&mut self) {
378        unsafe { llama_cpp_sys_4::mtp_session_free(self.raw.as_ptr()) }
379    }
380}
381
382impl std::fmt::Debug for Eagle3Session {
383    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384        f.debug_struct("Eagle3Session")
385            .field("config", &self.config)
386            .finish()
387    }
388}