#![allow(missing_docs)]
use crate::model::TreeDecoder;
use crate::{Error, Result};
use candle_core::{DType, Device, IndexOp, Module, Tensor, D};
use candle_nn::{linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
use std::path::Path;
/// Shape hyper-parameters for the single-layer EAGLE draft head.
///
/// The values are read by `DraftAttention::load`, `DraftMlp::load` and
/// `EagleDraftCandle::from_var_builder` to size the projections, the
/// embedding table, the RMS norm and the rotary tables.
#[derive(Debug, Clone)]
pub struct EagleDraftConfig {
    /// Width of the hidden state flowing through the draft layer.
    pub hidden_size: usize,
    /// Number of rows in the token-embedding table.
    pub vocab_size: usize,
    /// Number of query heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads; query heads are spread over these.
    pub num_key_value_heads: usize,
    /// Inner width of the gated MLP.
    pub intermediate_size: usize,
    /// Epsilon handed to the RMS norm.
    pub rms_norm_eps: f64,
    /// Base of the rotary inverse-frequency schedule.
    pub rope_theta: f32,
    /// Number of positions covered by the precomputed rotary tables.
    pub max_position_embeddings: usize,
}

impl EagleDraftConfig {
    /// Preset with the dimensions used for the Llama-3 8B draft head.
    pub fn eagle_llama3_8b() -> Self {
        Self {
            vocab_size: 128_256,
            hidden_size: 4096,
            intermediate_size: 14_336,
            num_attention_heads: 32,
            num_key_value_heads: 8,
            max_position_embeddings: 2048,
            rope_theta: 500_000.0,
            rms_norm_eps: 1e-5,
        }
    }

    /// Preset with the dimensions used for the Llama-2 chat 7B draft head.
    pub fn eagle_llama2_chat_7b() -> Self {
        Self {
            vocab_size: 32_000,
            hidden_size: 4096,
            intermediate_size: 11_008,
            num_attention_heads: 32,
            num_key_value_heads: 32,
            max_position_embeddings: 4096,
            rope_theta: 10_000.0,
            rms_norm_eps: 1e-5,
        }
    }

    /// Per-head width: `hidden_size / num_attention_heads` (integer division).
    fn head_dim(&self) -> usize {
        let Self {
            hidden_size,
            num_attention_heads,
            ..
        } = *self;
        hidden_size / num_attention_heads
    }
}
/// Self-attention sub-layer of the draft head: four bias-free projections,
/// precomputed rotary tables and a key/value cache that grows with every
/// `forward` call until `clear_kv_cache` is invoked.
#[derive(Debug, Clone)]
struct DraftAttention {
/// Query projection, `hidden_size -> n_head * head_dim`.
q_proj: Linear,
/// Key projection, `hidden_size -> n_kv_head * head_dim`.
k_proj: Linear,
/// Value projection, `hidden_size -> n_kv_head * head_dim`.
v_proj: Linear,
/// Output projection, `n_head * head_dim -> hidden_size`.
o_proj: Linear,
/// Cosine of `position * inv_freq`; one row per position, built in `load`.
cos: Tensor,
/// Sine of `position * inv_freq`; same layout as `cos`.
sin: Tensor,
/// Number of query heads.
n_head: usize,
/// Number of key/value heads.
n_kv_head: usize,
/// Width of a single head.
head_dim: usize,
/// Keys and values from earlier `forward` calls (after rotary embedding for
/// the keys); `None` until the first call and after `clear_kv_cache`.
kv_cache: Option<(Tensor, Tensor)>,
}
impl DraftAttention {
fn load(
cfg: &EagleDraftConfig,
vb: VarBuilder<'_>,
dev: &Device,
dtype: DType,
) -> Result<Self> {
let h = cfg.hidden_size;
let n = cfg.num_attention_heads;
let n_kv = cfg.num_key_value_heads;
let head_dim = cfg.head_dim();
let q_proj = linear_no_bias(h, n * head_dim, vb.pp("q_proj")).map_err(Error::Candle)?;
let k_proj = linear_no_bias(h, n_kv * head_dim, vb.pp("k_proj")).map_err(Error::Candle)?;
let v_proj = linear_no_bias(h, n_kv * head_dim, vb.pp("v_proj")).map_err(Error::Candle)?;
let o_proj = linear_no_bias(n * head_dim, h, vb.pp("o_proj")).map_err(Error::Candle)?;
let inv_freq: Vec<f32> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
.collect();
let inv_freq_t = Tensor::from_vec(inv_freq.clone(), (1, inv_freq.len()), dev)
.map_err(Error::Candle)?
.to_dtype(dtype)
.map_err(Error::Candle)?;
let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)
.map_err(Error::Candle)?
.to_dtype(dtype)
.map_err(Error::Candle)?
.reshape((cfg.max_position_embeddings, 1))
.map_err(Error::Candle)?;
let freqs = t.matmul(&inv_freq_t).map_err(Error::Candle)?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
cos: freqs.cos().map_err(Error::Candle)?,
sin: freqs.sin().map_err(Error::Candle)?,
n_head: n,
n_kv_head: n_kv,
head_dim,
kv_cache: None,
})
}
fn forward(&mut self, xs: &Tensor, position: usize) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3().map_err(Error::Candle)?;
let q = self
.q_proj
.forward(xs)
.map_err(Error::Candle)?
.reshape((b_sz, q_len, self.n_head, self.head_dim))
.map_err(Error::Candle)?
.transpose(1, 2)
.map_err(Error::Candle)?
.contiguous()
.map_err(Error::Candle)?;
let k = self
.k_proj
.forward(xs)
.map_err(Error::Candle)?
.reshape((b_sz, q_len, self.n_kv_head, self.head_dim))
.map_err(Error::Candle)?
.transpose(1, 2)
.map_err(Error::Candle)?
.contiguous()
.map_err(Error::Candle)?;
let v = self
.v_proj
.forward(xs)
.map_err(Error::Candle)?
.reshape((b_sz, q_len, self.n_kv_head, self.head_dim))
.map_err(Error::Candle)?
.transpose(1, 2)
.map_err(Error::Candle)?;
let cos = self.cos.narrow(0, position, q_len).map_err(Error::Candle)?;
let sin = self.sin.narrow(0, position, q_len).map_err(Error::Candle)?;
let q = candle_nn::rotary_emb::rope(&q, &cos, &sin).map_err(Error::Candle)?;
let k = candle_nn::rotary_emb::rope(&k, &cos, &sin).map_err(Error::Candle)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((pk, pv)) => (
Tensor::cat(&[pk, &k], 2).map_err(Error::Candle)?,
Tensor::cat(&[pv, &v], 2).map_err(Error::Candle)?,
),
};
self.kv_cache = Some((k.clone(), v.clone()));
let n_rep = self.n_head / self.n_kv_head;
let k = candle_transformers::utils::repeat_kv(k, n_rep)
.map_err(Error::Candle)?
.contiguous()
.map_err(Error::Candle)?;
let v = candle_transformers::utils::repeat_kv(v, n_rep)
.map_err(Error::Candle)?
.contiguous()
.map_err(Error::Candle)?;
let scale = 1f64 / (self.head_dim as f64).sqrt();
let attn = (q
.matmul(&k.t().map_err(Error::Candle)?)
.map_err(Error::Candle)?
* scale)
.map_err(Error::Candle)?;
let prev_len = self
.kv_cache
.as_ref()
.map(|(k, _)| k.dim(2).unwrap_or(0))
.unwrap_or(0)
- q_len;
let attn = if q_len <= 1 {
attn
} else {
let total = prev_len + q_len;
let mut data = vec![0f32; q_len * total];
for i in 0..q_len {
for j in 0..q_len {
if j > i {
data[i * total + prev_len + j] = f32::NEG_INFINITY;
}
}
}
let bias = Tensor::from_slice(&data, (q_len, total), xs.device())
.map_err(Error::Candle)?
.to_dtype(xs.dtype())
.map_err(Error::Candle)?
.reshape((1, 1, q_len, total))
.map_err(Error::Candle)?;
attn.broadcast_add(&bias).map_err(Error::Candle)?
};
let attn = candle_nn::ops::softmax_last_dim(&attn).map_err(Error::Candle)?;
let y = attn.matmul(&v).map_err(Error::Candle)?;
let y = y
.transpose(1, 2)
.map_err(Error::Candle)?
.reshape((b_sz, q_len, self.n_head * self.head_dim))
.map_err(Error::Candle)?;
self.o_proj.forward(&y).map_err(Error::Candle)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None;
}
}
/// Gated feed-forward sub-layer of the draft head
/// (`down(silu(gate(x)) * up(x))`, all projections bias-free).
#[derive(Debug, Clone)]
struct DraftMlp {
/// Gate projection, `hidden_size -> intermediate_size`; its output goes through SiLU.
gate: Linear,
/// Up projection, `hidden_size -> intermediate_size`.
up: Linear,
/// Down projection, `intermediate_size -> hidden_size`.
down: Linear,
}
impl DraftMlp {
fn load(cfg: &EagleDraftConfig, vb: VarBuilder<'_>) -> Result<Self> {
let h = cfg.hidden_size;
let i = cfg.intermediate_size;
Ok(Self {
gate: linear_no_bias(h, i, vb.pp("gate_proj")).map_err(Error::Candle)?,
up: linear_no_bias(h, i, vb.pp("up_proj")).map_err(Error::Candle)?,
down: linear_no_bias(i, h, vb.pp("down_proj")).map_err(Error::Candle)?,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let g = candle_nn::ops::silu(&self.gate.forward(xs).map_err(Error::Candle)?)
.map_err(Error::Candle)?;
let u = self.up.forward(xs).map_err(Error::Candle)?;
self.down
.forward(&(g * u).map_err(Error::Candle)?)
.map_err(Error::Candle)
}
}
/// Single-layer EAGLE draft head: token embedding, a fusion projection, one
/// attention sub-layer and one gated MLP (see `forward` for the data flow).
pub struct EagleDraftCandle {
/// Shape parameters this head was built with.
config: EagleDraftConfig,
/// Token-embedding table (`vocab_size x hidden_size`).
embed_tokens: Embedding,
/// Bias-free projection `2 * hidden_size -> hidden_size` applied to the
/// concatenation of the target hidden state and the token embedding.
fc: Linear,
/// Attention sub-layer, loaded from `layers.0.self_attn`.
attn: DraftAttention,
/// RMS norm applied before the MLP, loaded from `layers.0.post_attention_layernorm`.
post_attention_layernorm: RmsNorm,
/// Gated MLP, loaded from `layers.0.mlp`.
mlp: DraftMlp,
}
impl std::fmt::Debug for EagleDraftCandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EagleDraftCandle")
.field("hidden_size", &self.config.hidden_size)
.field("vocab_size", &self.config.vocab_size)
.finish()
}
}
impl EagleDraftCandle {
pub fn config(&self) -> &EagleDraftConfig {
&self.config
}
pub fn from_pth(
config: &EagleDraftConfig,
path: impl AsRef<Path>,
device: &Device,
dtype: DType,
) -> Result<Self> {
let vb = VarBuilder::from_pth(path.as_ref(), dtype, device).map_err(Error::Candle)?;
Self::from_var_builder(config, vb, device, dtype)
}
pub fn from_var_builder(
config: &EagleDraftConfig,
vb: VarBuilder<'_>,
device: &Device,
dtype: DType,
) -> Result<Self> {
let embed_tokens =
candle_nn::embedding(config.vocab_size, config.hidden_size, vb.pp("embed_tokens"))
.map_err(Error::Candle)?;
let fc = linear_no_bias(2 * config.hidden_size, config.hidden_size, vb.pp("fc"))
.map_err(Error::Candle)?;
let attn = DraftAttention::load(config, vb.pp("layers.0.self_attn"), device, dtype)?;
let post_attention_layernorm = rms_norm(
config.hidden_size,
config.rms_norm_eps,
vb.pp("layers.0.post_attention_layernorm"),
)
.map_err(Error::Candle)?;
let mlp = DraftMlp::load(config, vb.pp("layers.0.mlp"))?;
Ok(Self {
config: config.clone(),
embed_tokens,
fc,
attn,
post_attention_layernorm,
mlp,
})
}
pub fn reset(&mut self) {
self.attn.clear_kv_cache();
}
pub fn forward(
&mut self,
target_hidden: &Tensor,
token_ids: &Tensor,
position: usize,
) -> Result<Tensor> {
let token_emb = self
.embed_tokens
.forward(token_ids)
.map_err(Error::Candle)?;
let target_hidden_owned;
let target_hidden_use: &Tensor = if target_hidden.dtype() != token_emb.dtype() {
target_hidden_owned = target_hidden
.to_dtype(token_emb.dtype())
.map_err(Error::Candle)?;
&target_hidden_owned
} else {
target_hidden
};
let combined =
Tensor::cat(&[target_hidden_use, &token_emb], D::Minus1).map_err(Error::Candle)?;
let xs = self.fc.forward(&combined).map_err(Error::Candle)?;
let res = xs;
let attn = self.attn.forward(&res, position)?;
let xs = (attn + &res).map_err(Error::Candle)?;
let res = &xs;
let xs_n = self
.post_attention_layernorm
.forward(&xs)
.map_err(Error::Candle)?;
let m = self.mlp.forward(&xs_n)?;
(m + res).map_err(Error::Candle)
}
}
/// Knobs for one `run_eagle` invocation.
#[derive(Debug, Clone)]
pub struct EagleRunConfig {
    /// How many top candidates are kept from every draft step.
    pub top_k_per_step: usize,
    /// How many draft steps are taken per speculation round.
    pub draft_depth: usize,
    /// Optional cap on the draft-tree size; when set, the cartesian tree is
    /// pruned before verification.
    pub max_tree_nodes: Option<usize>,
    /// When `true`, the root logits come from a separate `next_logits` call
    /// and the bonus token is never reused from the tree.
    pub strict_root_gemv: bool,
    /// Sampling temperature. `run_eagle` currently discards this value.
    pub temperature: f32,
    /// Nucleus threshold. `run_eagle` currently discards this value.
    pub top_p: f32,
}

/// Defaults: 2 candidates per step, depth 4, no pruning, fast (non-strict)
/// root handling, temperature 0 and top-p 1.
impl Default for EagleRunConfig {
    fn default() -> Self {
        let (temperature, top_p) = (0.0, 1.0);
        Self {
            max_tree_nodes: None,
            strict_root_gemv: false,
            draft_depth: 4,
            top_k_per_step: 2,
            temperature,
            top_p,
        }
    }
}
pub fn run_eagle<T, R>(
target: &mut T,
draft: &mut EagleDraftCandle,
prompt: &[u32],
max_new_tokens: usize,
config: &EagleRunConfig,
rng: &mut R,
) -> Result<Vec<u32>>
where
T: TreeDecoder + ?Sized,
R: rand::Rng + ?Sized,
{
use crate::methods::medusa::top_k_indices;
target.reset();
if prompt.is_empty() {
return Err(Error::Sampling("EAGLE requires non-empty prompt".into()));
}
let mut target_hidden = target.observe_returning_last_hidden(prompt)?;
let mut generated = Vec::with_capacity(max_new_tokens);
while generated.len() < max_new_tokens {
let root_token = *target
.history()
.last()
.ok_or_else(|| Error::Sampling("EAGLE requires non-empty prompt".into()))?;
let hidden_reshaped = target_hidden
.unsqueeze(0)
.map_err(Error::Candle)?
.unsqueeze(0)
.map_err(Error::Candle)?;
draft.reset();
let history_len = target.history_len();
let mut per_step_top_k: Vec<Vec<u32>> = Vec::with_capacity(config.draft_depth);
let mut per_step_top_k_log_probs: Vec<Vec<f32>> = Vec::with_capacity(config.draft_depth);
let mut current_hidden = hidden_reshaped;
let mut current_token_ids =
Tensor::from_slice(&[root_token], (1, 1), target_hidden.device())
.map_err(Error::Candle)?;
for step in 0..config.draft_depth {
let draft_hidden =
draft.forward(¤t_hidden, ¤t_token_ids, history_len + step)?;
let logits = target.apply_lm_head(&draft_hidden)?;
let last = logits
.i((0, draft_hidden.dim(1).map_err(Error::Candle)? - 1, ..))
.map_err(Error::Candle)?
.to_dtype(DType::F32)
.map_err(Error::Candle)?
.to_vec1::<f32>()
.map_err(Error::Candle)?;
let top_idx: Vec<usize> = top_k_indices(&last, config.top_k_per_step);
let max_l = last.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let lse = last.iter().map(|&v| (v - max_l).exp()).sum::<f32>().ln() + max_l;
let top_log_probs: Vec<f32> = top_idx.iter().map(|&i| last[i] - lse).collect();
let top: Vec<u32> = top_idx.iter().map(|&i| i as u32).collect();
per_step_top_k.push(top.clone());
per_step_top_k_log_probs.push(top_log_probs);
let next_id = top[0];
current_token_ids = Tensor::from_slice(&[next_id], (1, 1), target_hidden.device())
.map_err(Error::Candle)?;
current_hidden = draft_hidden;
}
let full_tree = crate::methods::medusa::MedusaHeads::from_config(
crate::methods::medusa::MedusaConfig {
n_heads: config.draft_depth,
hidden_size: draft.config.hidden_size,
vocab_size: draft.config.vocab_size,
residual_layers: 1,
},
)
.build_draft_tree(
root_token,
&per_step_top_k,
crate::methods::medusa::TreeTopology::CartesianProduct,
)?;
let tree = if let Some(max_n) = config.max_tree_nodes {
prune_cartesian_tree(&full_tree, &per_step_top_k_log_probs, max_n)?
} else {
full_tree
};
let strict_root_logits: Option<Vec<f32>> = if config.strict_root_gemv {
Some(target.next_logits()?)
} else {
None
};
let (mut per_node_logits, _per_node_hidden) =
target.tree_logits_keep_kv(&tree)?;
if let Some(root_gemv) = strict_root_logits {
per_node_logits[0] = root_gemv;
}
let mut best_path: Vec<usize> = vec![0];
for path in tree.paths() {
let mut accepted = 0;
for w in path.windows(2) {
let parent = w[0];
let child = w[1];
let candidate = tree.token_at(child) as usize;
let parent_dist = &per_node_logits[parent];
let argmax = parent_dist
.iter()
.enumerate()
.fold((0usize, f32::NEG_INFINITY), |(bi, bv), (i, &v)| {
if v > bv {
(i, v)
} else {
(bi, bv)
}
})
.0;
if argmax == candidate {
accepted += 1;
} else {
break;
}
}
if accepted + 1 > best_path.len() {
best_path = path[..=accepted].to_vec();
}
}
let deepest_idx = *best_path.last().unwrap();
let bonus = per_node_logits[deepest_idx]
.iter()
.enumerate()
.fold((0usize, f32::NEG_INFINITY), |(bi, bv), (i, &v)| {
if v > bv {
(i, v)
} else {
(bi, bv)
}
})
.0 as u32;
let mut committed: Vec<u32> = best_path
.iter()
.skip(1)
.map(|&i| tree.token_at(i))
.collect();
committed.push(bonus);
let eos_set = target.eos_token_ids();
let eos_pos = committed.iter().position(|t| eos_set.contains(t));
let stop = eos_pos.is_some();
if let Some(p) = eos_pos {
committed.truncate(p + 1);
}
let path_committed: Vec<u32> = best_path
.iter()
.skip(1)
.map(|&i| tree.token_at(i))
.collect();
let path_eos_index = path_committed.iter().position(|t| eos_set.contains(t));
if let Some(idx) = path_eos_index {
let kept_path: Vec<usize> = best_path[..=idx + 1].to_vec();
target.commit_tree_path(&tree, &kept_path)?;
} else {
let mut bonus_in_tree: Option<usize> = None;
if !config.strict_root_gemv {
for n in 1..tree.len() {
if tree.parent_of(n) == deepest_idx && tree.token_at(n) == bonus {
bonus_in_tree = Some(n);
break;
}
}
}
if let Some(bn) = bonus_in_tree {
let mut kept_path = best_path.clone();
kept_path.push(bn);
target.commit_tree_path(&tree, &kept_path)?;
target_hidden = _per_node_hidden[bn].clone();
} else {
target.commit_tree_path(&tree, &best_path)?;
target_hidden = target.observe_returning_last_hidden(&[bonus])?;
}
}
generated.extend_from_slice(&committed);
if stop {
break;
}
}
let _ = (rng, config.temperature, config.top_p);
generated.truncate(max_new_tokens);
Ok(generated)
}
/// Public entry point to the module-private `prune_cartesian_tree`.
///
/// All three arguments are forwarded unchanged and the result (or error) is
/// returned as-is.
pub fn prune_cartesian_tree_pub(
    full: &crate::tree::DraftTree,
    per_step_log_probs: &[Vec<f32>],
    max_total_nodes: usize,
) -> Result<crate::tree::DraftTree> {
    let pruned = prune_cartesian_tree(full, per_step_log_probs, max_total_nodes)?;
    Ok(pruned)
}
fn prune_cartesian_tree(
full: &crate::tree::DraftTree,
per_step_log_probs: &[Vec<f32>],
max_total_nodes: usize,
) -> Result<crate::tree::DraftTree> {
use crate::tree::DraftTree;
if full.len() <= max_total_nodes {
return clone_tree(full);
}
let mut children_of: Vec<Vec<usize>> = vec![Vec::new(); full.len()];
for n in 1..full.len() {
let p = full.parent_of(n);
children_of[p].push(n);
}
let mut candidate_index_of = vec![0usize; full.len()];
for siblings in &children_of {
for (rank, &c) in siblings.iter().enumerate() {
candidate_index_of[c] = rank;
}
}
let mut scores: Vec<(usize, f32)> = Vec::with_capacity(full.len() - 1);
for n in 1..full.len() {
let depth = full.depth_of(n);
let mut s = 0f32;
let mut cur = n;
for d in (0..depth).rev() {
let cand = candidate_index_of[cur];
s += per_step_log_probs[d][cand];
cur = full.parent_of(cur);
}
scores.push((n, s));
}
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let want = max_total_nodes.saturating_sub(1).min(scores.len());
let mut keep = vec![false; full.len()];
keep[0] = true;
for &(n, _) in scores.iter().take(want) {
keep[n] = true;
}
for n in (1..full.len()).rev() {
if keep[n] {
keep[full.parent_of(n)] = true;
}
}
let mut order: Vec<usize> = (0..full.len()).filter(|&n| keep[n]).collect();
order.sort_by_key(|&n| full.depth_of(n));
let new_index: std::collections::HashMap<usize, usize> = order
.iter()
.enumerate()
.map(|(new_i, &old_i)| (old_i, new_i))
.collect();
let mut entries: Vec<(usize, u32)> = Vec::with_capacity(order.len());
for &old in &order {
let parent_old = if old == 0 { 0 } else { full.parent_of(old) };
let parent_new = *new_index.get(&parent_old).expect("ancestor present");
entries.push((parent_new, full.token_at(old)));
}
DraftTree::from_parent_table(&entries)
.map_err(|e| Error::Sampling(format!("pruned tree invalid: {e}")))
}
fn clone_tree(t: &crate::tree::DraftTree) -> Result<crate::tree::DraftTree> {
let entries: Vec<(usize, u32)> = (0..t.len())
.map(|i| {
let parent = if i == 0 { 0 } else { t.parent_of(i) };
(parent, t.token_at(i))
})
.collect();
crate::tree::DraftTree::from_parent_table(&entries)
.map_err(|e| Error::Sampling(format!("clone tree invalid: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::methods::medusa::{MedusaConfig, MedusaHeads, TreeTopology};

    /// Builds a cartesian-product draft tree from `per_step` candidates under
    /// `root`, using the small fixed Medusa configuration shared by the
    /// pruning tests. Panics with `why` if construction fails.
    fn cartesian_tree(root: u32, per_step: &[Vec<u32>], why: &str) -> crate::tree::DraftTree {
        MedusaHeads::from_config(MedusaConfig {
            n_heads: 2,
            hidden_size: 4,
            vocab_size: 100,
            residual_layers: 1,
        })
        .build_draft_tree(root, per_step, TreeTopology::CartesianProduct)
        .expect(why)
    }

    /// The Llama-3 8B preset exposes the expected dimensions.
    #[test]
    fn config_defaults_match_eagle_llama3_8b() {
        let preset = EagleDraftConfig::eagle_llama3_8b();
        assert_eq!(preset.hidden_size, 4096);
        assert_eq!(preset.vocab_size, 128256);
        assert_eq!(preset.num_attention_heads, 32);
        assert_eq!(preset.num_key_value_heads, 8);
        assert_eq!(preset.head_dim(), 128);
    }

    /// Cloning a config keeps its fields.
    #[test]
    fn config_clone() {
        let original = EagleDraftConfig::eagle_llama3_8b();
        let copy = original.clone();
        assert_eq!(original.hidden_size, copy.hidden_size);
    }

    /// Strict root handling is opt-in.
    #[test]
    fn run_config_default_strict_off() {
        let defaults = EagleRunConfig::default();
        assert!(!defaults.strict_root_gemv, "fast mode is the default");
    }

    /// Strict mode can be switched on through struct-update syntax.
    #[test]
    fn run_config_strict_toggle_compiles() {
        let strict = EagleRunConfig {
            strict_root_gemv: true,
            ..Default::default()
        };
        assert!(strict.strict_root_gemv);
    }

    /// Pruning a 7-node tree to 4 nodes keeps the root and the best
    /// candidate of each layer.
    #[test]
    fn prune_keeps_root_and_top_paths() {
        let cart = cartesian_tree(42, &[vec![10, 20], vec![30, 40]], "build cart");
        assert_eq!(cart.len(), 7);

        let log_probs = vec![vec![-0.1f32, -2.0], vec![-0.2, -3.0]];
        let pruned = prune_cartesian_tree(&cart, &log_probs, 4).expect("prune");
        assert!(pruned.len() <= 4, "pruned should be ≤ 4 nodes");
        assert!(!pruned.is_empty());
        assert_eq!(pruned.token_at(0), 42, "root preserved");

        let mut tokens: Vec<u32> = Vec::with_capacity(pruned.len());
        for node in 0..pruned.len() {
            tokens.push(pruned.token_at(node));
        }
        assert!(tokens.contains(&10), "best layer-0 child kept");
        assert!(tokens.contains(&30), "best layer-1 child kept");
    }

    /// A tree already within the budget comes back with the same size.
    #[test]
    fn prune_returns_full_tree_when_under_limit() {
        let cart = cartesian_tree(1, &[vec![2], vec![3]], "build");
        assert_eq!(cart.len(), 3);

        let pruned = prune_cartesian_tree(&cart, &[vec![-0.1], vec![-0.2]], 100).expect("prune");
        assert_eq!(pruned.len(), 3);
    }
}