#![allow(missing_docs)]
use crate::{Error, Result};
use candle_core::{DType, Device, IndexOp, Module, Tensor, D};
use candle_nn::{linear_no_bias, rms_norm, Linear, RmsNorm, VarBuilder};
use std::path::Path;
/// Architecture hyper-parameters of an EAGLE-3 draft head (one decoder layer plus the
/// `fc` feature-fusion projection and a reduced-vocabulary `lm_head`).
#[derive(Debug, Clone)]
pub struct Eagle3DraftConfig {
    /// Width of the draft layer's residual stream. `fc` maps the three concatenated target
    /// feature vectors (`3 * hidden_size`) down to this width.
    pub hidden_size: usize,
    /// Output width of the draft `lm_head`; also the length of the `d2t` table.
    pub draft_vocab_size: usize,
    /// Vocabulary size of the target model; length of the target-reachability mask.
    pub target_vocab_size: usize,
    /// Number of query heads in the draft attention.
    pub num_attention_heads: usize,
    /// Number of key/value heads; K/V are repeated `num_attention_heads / num_key_value_heads`
    /// times before attention.
    pub num_key_value_heads: usize,
    /// Hidden width of the gated MLP.
    pub intermediate_size: usize,
    /// Epsilon used by every RMSNorm in the draft head.
    pub rms_norm_eps: f64,
    /// RoPE base frequency.
    pub rope_theta: f32,
    /// Number of rows in the precomputed RoPE tables; positions at or beyond this make the
    /// draft forward pass fail.
    pub max_position_embeddings: usize,
}
impl Eagle3DraftConfig {
    /// Preset for the EAGLE-3 draft head paired with a Llama-3.1-8B target (per the name):
    /// 4096-wide, 32 query / 8 KV heads, 32k draft vocabulary over a 128256-token target
    /// vocabulary.
    pub fn eagle3_llama3_1_8b() -> Self {
        let hidden_size = 4096;
        let num_attention_heads = 32;
        let num_key_value_heads = 8;
        let intermediate_size = 14336;
        let draft_vocab_size = 32000;
        let target_vocab_size = 128256;
        let max_position_embeddings = 2048;
        let rope_theta = 500_000.0;
        let rms_norm_eps = 1e-5;
        Self {
            hidden_size,
            draft_vocab_size,
            target_vocab_size,
            num_attention_heads,
            num_key_value_heads,
            intermediate_size,
            rms_norm_eps,
            rope_theta,
            max_position_embeddings,
        }
    }

    /// Per-head width: `hidden_size / num_attention_heads` (integer division).
    fn head_dim(&self) -> usize {
        let Self {
            hidden_size,
            num_attention_heads,
            ..
        } = self;
        hidden_size / num_attention_heads
    }
}
/// The single decoder layer of the EAGLE-3 draft head, with its own KV cache.
#[derive(Debug, Clone)]
struct Midlayer {
    /// RMSNorm applied to the token embedding (`target_hidden` argument of `forward`).
    input_layernorm: RmsNorm,
    /// RMSNorm applied to the projected feature stream before attention.
    hidden_norm: RmsNorm,
    // Attention projections; Q/K/V read the `2 * hidden_size` concatenation of both normed
    // inputs, O maps back to `hidden_size`.
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    /// RMSNorm in front of the MLP.
    post_attention_layernorm: RmsNorm,
    // Gated (SiLU) MLP projections.
    mlp_gate: Linear,
    mlp_up: Linear,
    mlp_down: Linear,
    /// RoPE cosine table, `(max_position_embeddings, head_dim / 2)`.
    cos: Tensor,
    /// RoPE sine table, same shape as `cos`.
    sin: Tensor,
    /// Number of query heads.
    n_head: usize,
    /// Number of key/value heads (repeated to `n_head` before attention).
    n_kv_head: usize,
    /// Per-head width.
    head_dim: usize,
    /// Cached `(k, v)` after RoPE, concatenated along the sequence dim (dim 2); `None` after
    /// `clear_kv_cache`.
    kv_cache: Option<(Tensor, Tensor)>,
}
impl Midlayer {
fn load(
cfg: &Eagle3DraftConfig,
vb: VarBuilder<'_>,
dev: &Device,
dtype: DType,
) -> Result<Self> {
let h = cfg.hidden_size;
let i = cfg.intermediate_size;
let n = cfg.num_attention_heads;
let n_kv = cfg.num_key_value_heads;
let head_dim = cfg.head_dim();
let input_layernorm =
rms_norm(h, cfg.rms_norm_eps, vb.pp("input_layernorm")).map_err(Error::Candle)?;
let hidden_norm =
rms_norm(h, cfg.rms_norm_eps, vb.pp("hidden_norm")).map_err(Error::Candle)?;
let q_proj = linear_no_bias(2 * h, n * head_dim, vb.pp("self_attn.q_proj"))
.map_err(Error::Candle)?;
let k_proj = linear_no_bias(2 * h, n_kv * head_dim, vb.pp("self_attn.k_proj"))
.map_err(Error::Candle)?;
let v_proj = linear_no_bias(2 * h, n_kv * head_dim, vb.pp("self_attn.v_proj"))
.map_err(Error::Candle)?;
let o_proj =
linear_no_bias(n * head_dim, h, vb.pp("self_attn.o_proj")).map_err(Error::Candle)?;
let post_attention_layernorm =
rms_norm(h, cfg.rms_norm_eps, vb.pp("post_attention_layernorm"))
.map_err(Error::Candle)?;
let mlp_gate = linear_no_bias(h, i, vb.pp("mlp.gate_proj")).map_err(Error::Candle)?;
let mlp_up = linear_no_bias(h, i, vb.pp("mlp.up_proj")).map_err(Error::Candle)?;
let mlp_down = linear_no_bias(i, h, vb.pp("mlp.down_proj")).map_err(Error::Candle)?;
let inv_freq: Vec<f32> = (0..head_dim)
.step_by(2)
.map(|j| 1f32 / cfg.rope_theta.powf(j as f32 / head_dim as f32))
.collect();
let inv_freq_t = Tensor::from_vec(inv_freq.clone(), (1, inv_freq.len()), dev)
.map_err(Error::Candle)?
.to_dtype(dtype)
.map_err(Error::Candle)?;
let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)
.map_err(Error::Candle)?
.to_dtype(dtype)
.map_err(Error::Candle)?
.reshape((cfg.max_position_embeddings, 1))
.map_err(Error::Candle)?;
let freqs = t.matmul(&inv_freq_t).map_err(Error::Candle)?;
Ok(Self {
input_layernorm,
hidden_norm,
q_proj,
k_proj,
v_proj,
o_proj,
post_attention_layernorm,
mlp_gate,
mlp_up,
mlp_down,
cos: freqs.cos().map_err(Error::Candle)?,
sin: freqs.sin().map_err(Error::Candle)?,
n_head: n,
n_kv_head: n_kv,
head_dim,
kv_cache: None,
})
}
fn forward(
&mut self,
target_hidden: &Tensor,
projected_features: &Tensor,
position: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = target_hidden.dims3().map_err(Error::Candle)?;
let th = self
.input_layernorm
.forward(target_hidden)
.map_err(Error::Candle)?;
let pf = self
.hidden_norm
.forward(projected_features)
.map_err(Error::Candle)?;
let combined = Tensor::cat(&[&th, &pf], D::Minus1).map_err(Error::Candle)?;
let q = self
.q_proj
.forward(&combined)
.map_err(Error::Candle)?
.reshape((b_sz, q_len, self.n_head, self.head_dim))
.map_err(Error::Candle)?
.transpose(1, 2)
.map_err(Error::Candle)?
.contiguous()
.map_err(Error::Candle)?;
let k = self
.k_proj
.forward(&combined)
.map_err(Error::Candle)?
.reshape((b_sz, q_len, self.n_kv_head, self.head_dim))
.map_err(Error::Candle)?
.transpose(1, 2)
.map_err(Error::Candle)?
.contiguous()
.map_err(Error::Candle)?;
let v = self
.v_proj
.forward(&combined)
.map_err(Error::Candle)?
.reshape((b_sz, q_len, self.n_kv_head, self.head_dim))
.map_err(Error::Candle)?
.transpose(1, 2)
.map_err(Error::Candle)?;
let cos = self.cos.narrow(0, position, q_len).map_err(Error::Candle)?;
let sin = self.sin.narrow(0, position, q_len).map_err(Error::Candle)?;
let q = candle_nn::rotary_emb::rope(&q, &cos, &sin).map_err(Error::Candle)?;
let k = candle_nn::rotary_emb::rope(&k, &cos, &sin).map_err(Error::Candle)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((pk, pv)) => (
Tensor::cat(&[pk, &k], 2).map_err(Error::Candle)?,
Tensor::cat(&[pv, &v], 2).map_err(Error::Candle)?,
),
};
self.kv_cache = Some((k.clone(), v.clone()));
let n_rep = self.n_head / self.n_kv_head;
let k = candle_transformers::utils::repeat_kv(k, n_rep)
.map_err(Error::Candle)?
.contiguous()
.map_err(Error::Candle)?;
let v = candle_transformers::utils::repeat_kv(v, n_rep)
.map_err(Error::Candle)?
.contiguous()
.map_err(Error::Candle)?;
let scale = 1f64 / (self.head_dim as f64).sqrt();
let attn = (q
.matmul(&k.t().map_err(Error::Candle)?)
.map_err(Error::Candle)?
* scale)
.map_err(Error::Candle)?;
let attn = candle_nn::ops::softmax_last_dim(&attn).map_err(Error::Candle)?;
let y = attn.matmul(&v).map_err(Error::Candle)?;
let y = y
.transpose(1, 2)
.map_err(Error::Candle)?
.reshape((b_sz, q_len, self.n_head * self.head_dim))
.map_err(Error::Candle)?;
let attn_out = self.o_proj.forward(&y).map_err(Error::Candle)?;
let after_attn = (attn_out + projected_features).map_err(Error::Candle)?;
let x_n = self
.post_attention_layernorm
.forward(&after_attn)
.map_err(Error::Candle)?;
let g = candle_nn::ops::silu(&self.mlp_gate.forward(&x_n).map_err(Error::Candle)?)
.map_err(Error::Candle)?;
let u = self.mlp_up.forward(&x_n).map_err(Error::Candle)?;
let m = self
.mlp_down
.forward(&(g * u).map_err(Error::Candle)?)
.map_err(Error::Candle)?;
(m + after_attn).map_err(Error::Candle)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None;
}
}
/// EAGLE-3 draft head loaded with candle: feature-fusion `fc`, one decoder layer, final norm,
/// a reduced-vocabulary `lm_head`, and lookup tables between draft and target token ids.
pub struct Eagle3DraftCandle {
    /// Configuration the model was loaded with.
    config: Eagle3DraftConfig,
    /// Projects the concatenated `3 * hidden_size` target features down to `hidden_size`.
    fc: Linear,
    /// The single draft decoder layer (owns the draft KV cache).
    midlayer: Midlayer,
    /// Final RMSNorm before `lm_head`.
    norm: RmsNorm,
    /// `hidden_size -> draft_vocab_size` output head.
    lm_head: Linear,
    /// Host copy of the draft->target table derived from the checkpoint's `d2t` buffer, one
    /// entry per draft id. See `from_var_builder` for how entries are interpreted.
    d2t_host: Vec<i64>,
    /// Reverse lookup: target token id -> draft token id, for in-range entries only.
    target_to_draft: std::collections::HashMap<u32, u32>,
    /// `t2d_mask[t]` is true when target id `t` has a draft counterpart; length is
    /// `target_vocab_size`.
    t2d_mask: Vec<bool>,
}
impl std::fmt::Debug for Eagle3DraftCandle {
    /// Summarises the model by its widths only; weights and lookup tables are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cfg = &self.config;
        let mut out = f.debug_struct("Eagle3DraftCandle");
        out.field("hidden_size", &cfg.hidden_size);
        out.field("draft_vocab_size", &cfg.draft_vocab_size);
        out.field("target_vocab_size", &cfg.target_vocab_size);
        out.finish()
    }
}
impl Eagle3DraftCandle {
pub fn config(&self) -> &Eagle3DraftConfig {
&self.config
}
pub fn from_pth(
config: &Eagle3DraftConfig,
path: impl AsRef<Path>,
device: &Device,
dtype: DType,
) -> Result<Self> {
let vb = VarBuilder::from_pth(path.as_ref(), dtype, device).map_err(Error::Candle)?;
Self::from_var_builder(config, vb, device, dtype)
}
pub fn from_var_builder(
config: &Eagle3DraftConfig,
vb: VarBuilder<'_>,
device: &Device,
dtype: DType,
) -> Result<Self> {
let fc = linear_no_bias(3 * config.hidden_size, config.hidden_size, vb.pp("fc"))
.map_err(Error::Candle)?;
let midlayer = Midlayer::load(config, vb.pp("midlayer"), device, dtype)?;
let norm = rms_norm(config.hidden_size, config.rms_norm_eps, vb.pp("norm"))
.map_err(Error::Candle)?;
let lm_head = linear_no_bias(
config.hidden_size,
config.draft_vocab_size,
vb.pp("lm_head"),
)
.map_err(Error::Candle)?;
let d2t = vb
.get_with_hints_dtype(
config.draft_vocab_size,
"d2t",
Default::default(),
DType::I64,
)
.map_err(Error::Candle)?;
let d2t_host: Vec<i64> = d2t.to_vec1::<i64>().map_err(Error::Candle)?;
let mut target_to_draft = std::collections::HashMap::with_capacity(d2t_host.len());
let mut t2d_mask = vec![false; config.target_vocab_size];
for (draft_id, &target_id) in d2t_host.iter().enumerate() {
if target_id >= 0 && (target_id as u32) < config.target_vocab_size as u32 {
target_to_draft.insert(target_id as u32, draft_id as u32);
t2d_mask[target_id as usize] = true;
}
}
drop(d2t);
Ok(Self {
config: config.clone(),
fc,
midlayer,
norm,
lm_head,
d2t_host,
target_to_draft,
t2d_mask,
})
}
pub fn target_to_draft_token(&self, target_id: u32) -> Option<u32> {
self.target_to_draft.get(&target_id).copied()
}
pub fn reset(&mut self) {
self.midlayer.clear_kv_cache();
}
pub fn forward_hidden(
&mut self,
hidden_states: &Tensor,
input_emb: &Tensor,
position: usize,
) -> Result<Tensor> {
let inner = hidden_states.dims()[hidden_states.rank() - 1];
let projected = if inner == self.config.hidden_size {
hidden_states.clone()
} else {
self.fc.forward(hidden_states).map_err(Error::Candle)?
};
self.midlayer.forward(input_emb, &projected, position)
}
pub fn apply_norm_lm_head(&self, hidden: &Tensor) -> Result<Tensor> {
let h = self.norm.forward(hidden).map_err(Error::Candle)?;
self.lm_head.forward(&h).map_err(Error::Candle)
}
pub fn forward(
&mut self,
low: &Tensor,
mid: &Tensor,
high: &Tensor,
input_emb: &Tensor,
token_ids: &Tensor,
position: usize,
) -> Result<Tensor> {
let _ = token_ids; let combined = Tensor::cat(&[low, mid, high], D::Minus1).map_err(Error::Candle)?;
let h = self.forward_hidden(&combined, input_emb, position)?;
self.apply_norm_lm_head(&h)
}
pub fn draft_to_target_token(&self, draft_id: u32) -> Result<u32> {
let v = self.d2t_host.get(draft_id as usize).ok_or_else(|| {
Error::Sampling(format!(
"draft id {draft_id} out of range ({})",
self.d2t_host.len()
))
})?;
Ok(*v as u32)
}
pub fn target_token_is_reachable(&self, target_id: u32) -> Result<bool> {
Ok(self
.t2d_mask
.get(target_id as usize)
.copied()
.unwrap_or(false))
}
pub fn t2d_mask(&self) -> &[bool] {
&self.t2d_mask
}
pub fn mask_target_logits(&self, logits: &mut [f32]) {
debug_assert_eq!(logits.len(), self.t2d_mask.len());
for (i, slot) in logits.iter_mut().enumerate() {
if !self.t2d_mask[i] {
*slot = f32::NEG_INFINITY;
}
}
}
}
/// Knobs for [`run_eagle3`].
#[derive(Debug, Clone)]
pub struct Eagle3RunConfig {
    /// Number of draft candidates kept at each draft step.
    pub top_k_per_step: usize,
    /// Number of autoregressive draft steps per round (depth of the draft tree).
    pub draft_depth: usize,
    /// If set, the Cartesian draft tree is pruned to at most this many nodes.
    pub max_tree_nodes: Option<usize>,
    /// Target layers whose hidden states feed the draft head (exactly three are required).
    pub layer_indices: [usize; 3],
    /// When true, the root node's logits come from `next_logits()` instead of the tree pass.
    pub strict_root_gemv: bool,
    /// Sampling temperature; currently ignored by `run_eagle3`, which decodes greedily.
    pub temperature: f32,
    /// Nucleus-sampling threshold; currently ignored by `run_eagle3`, which decodes greedily.
    pub top_p: f32,
}
impl Eagle3RunConfig {
    /// Default low/mid/high feature layers for a target with `n_layers` decoder layers:
    /// `[1, n_layers / 2 - 1, n_layers - 4]`. Targets with fewer than 4 layers fall back
    /// to `[0, 0, 0]`.
    pub fn default_layers_for(n_layers: usize) -> [usize; 3] {
        match n_layers {
            0..=3 => [0; 3],
            n => {
                let low = 1;
                let mid = n / 2 - 1;
                let high = n - 4;
                [low, mid, high]
            }
        }
    }
}
impl Default for Eagle3RunConfig {
    /// Greedy defaults: 2 candidates per step, 4 draft steps, no pruning, no strict root
    /// GEMV. The feature layers are `[1, 15, 28]`, the 32-layer layout.
    fn default() -> Self {
        let layer_indices = [1, 15, 28];
        Self {
            draft_depth: 4,
            top_k_per_step: 2,
            layer_indices,
            max_tree_nodes: None,
            strict_root_gemv: false,
            top_p: 1.0,
            temperature: 0.0,
        }
    }
}
pub fn run_eagle3<T, R>(
target: &mut T,
draft: &mut Eagle3DraftCandle,
prompt: &[u32],
max_new_tokens: usize,
config: &Eagle3RunConfig,
rng: &mut R,
) -> Result<Vec<u32>>
where
T: crate::model::TreeDecoder + ?Sized,
R: rand::Rng + ?Sized,
{
use crate::methods::medusa::top_k_indices;
target.reset();
target.observe(prompt)?;
let mut generated = Vec::with_capacity(max_new_tokens);
while generated.len() < max_new_tokens {
let root_token = *target
.history()
.last()
.ok_or_else(|| Error::Sampling("EAGLE-3 requires non-empty prompt".into()))?;
let root_draft_id = draft.target_to_draft_token(root_token).unwrap_or(0);
let (final_h, mids) = target.last_hidden_states_multi(&config.layer_indices)?;
if mids.len() != 3 {
return Err(Error::Sampling(format!(
"EAGLE-3 expects 3 layers, got {}",
mids.len()
)));
}
let draft_dtype = draft.fc.weight().dtype();
let device = final_h.device().clone();
let to_3d = |t: &candle_core::Tensor| -> Result<candle_core::Tensor> {
let t = if t.dtype() != draft_dtype {
t.to_dtype(draft_dtype).map_err(Error::Candle)?
} else {
t.clone()
};
t.unsqueeze(0)
.map_err(Error::Candle)?
.unsqueeze(0)
.map_err(Error::Candle)
};
let low = to_3d(&mids[0])?;
let mid = to_3d(&mids[1])?;
let high = to_3d(&mids[2])?;
let mut hidden_in =
Tensor::cat(&[&low, &mid, &high], D::Minus1).map_err(Error::Candle)?;
draft.reset();
let history_len = target.history_len();
let mut per_step_top_k_target: Vec<Vec<u32>> = Vec::with_capacity(config.draft_depth);
let mut per_step_top_k_log_probs: Vec<Vec<f32>> = Vec::with_capacity(config.draft_depth);
let mut current_target_id = root_token;
for step in 0..config.draft_depth {
let token_ids = Tensor::from_slice(&[current_target_id], (1, 1), &device)
.map_err(Error::Candle)?;
let input_emb = target.embed_tokens(&token_ids)?;
let input_emb = if input_emb.dtype() != draft_dtype {
input_emb.to_dtype(draft_dtype).map_err(Error::Candle)?
} else {
input_emb
};
let out_hidden = draft.forward_hidden(&hidden_in, &input_emb, history_len + step)?;
let logits = draft.apply_norm_lm_head(&out_hidden)?;
let last = logits
.i((0, logits.dim(1).map_err(Error::Candle)? - 1, ..))
.map_err(Error::Candle)?
.to_dtype(DType::F32)
.map_err(Error::Candle)?
.to_vec1::<f32>()
.map_err(Error::Candle)?;
let top_idx: Vec<usize> = top_k_indices(&last, config.top_k_per_step);
let max_l = last.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let lse = last.iter().map(|&v| (v - max_l).exp()).sum::<f32>().ln() + max_l;
let log_probs: Vec<f32> = top_idx.iter().map(|&i| last[i] - lse).collect();
per_step_top_k_log_probs.push(log_probs);
let mut top_target = Vec::with_capacity(top_idx.len());
for &di in &top_idx {
top_target.push(draft.draft_to_target_token(di as u32)?);
}
per_step_top_k_target.push(top_target.clone());
hidden_in = out_hidden;
current_target_id = top_target[0];
let _ = root_draft_id; }
let full_tree = crate::methods::medusa::MedusaHeads::from_config(
crate::methods::medusa::MedusaConfig {
n_heads: config.draft_depth,
hidden_size: draft.config().hidden_size,
vocab_size: draft.config().target_vocab_size,
residual_layers: 1,
},
)
.build_draft_tree(
root_token,
&per_step_top_k_target,
crate::methods::medusa::TreeTopology::CartesianProduct,
)?;
let tree = if let Some(max_n) = config.max_tree_nodes {
crate::methods::eagle::prune_cartesian_tree_pub(
&full_tree,
&per_step_top_k_log_probs,
max_n,
)?
} else {
full_tree
};
let strict_root_logits: Option<Vec<f32>> = if config.strict_root_gemv {
Some(target.next_logits()?)
} else {
None
};
let (mut per_node_logits, _per_node_hidden) = target.tree_logits_keep_kv(&tree)?;
if let Some(root_gemv) = strict_root_logits {
per_node_logits[0] = root_gemv;
}
let mut best_path: Vec<usize> = vec![0];
for path in tree.paths() {
let mut accepted = 0;
for w in path.windows(2) {
let parent = w[0];
let child = w[1];
let candidate = tree.token_at(child) as usize;
let parent_dist = &per_node_logits[parent];
let argmax = parent_dist
.iter()
.enumerate()
.fold((0usize, f32::NEG_INFINITY), |(bi, bv), (i, &v)| {
if v > bv {
(i, v)
} else {
(bi, bv)
}
})
.0;
if argmax == candidate {
accepted += 1;
} else {
break;
}
}
if accepted + 1 > best_path.len() {
best_path = path[..=accepted].to_vec();
}
}
let deepest_idx = *best_path.last().unwrap();
let bonus = per_node_logits[deepest_idx]
.iter()
.enumerate()
.fold((0usize, f32::NEG_INFINITY), |(bi, bv), (i, &v)| {
if v > bv {
(i, v)
} else {
(bi, bv)
}
})
.0 as u32;
let path_committed: Vec<u32> = best_path
.iter()
.skip(1)
.map(|&i| tree.token_at(i))
.collect();
let mut committed = path_committed.clone();
committed.push(bonus);
let eos_set = target.eos_token_ids();
let eos_pos = committed.iter().position(|t| eos_set.contains(t));
let stop = eos_pos.is_some();
if let Some(p) = eos_pos {
committed.truncate(p + 1);
}
let path_eos_index = path_committed.iter().position(|t| eos_set.contains(t));
if let Some(idx) = path_eos_index {
let kept_path: Vec<usize> = best_path[..=idx + 1].to_vec();
target.commit_tree_path(&tree, &kept_path)?;
} else {
target.commit_tree_path(&tree, &best_path)?;
target.observe(&[bonus])?;
}
generated.extend_from_slice(&committed);
if stop {
break;
}
}
let _ = (rng, config.temperature, config.top_p);
generated.truncate(max_new_tokens);
Ok(generated)
}
#[cfg(test)]
mod tests {
    // Unit tests for the parts of this module that need no weights or devices.
    use super::*;

    // The Llama-3.1-8B preset exposes the expected widths and a 128-wide head.
    #[test]
    fn config_defaults_eagle3_llama3_1() {
        let c = Eagle3DraftConfig::eagle3_llama3_1_8b();
        assert_eq!(c.hidden_size, 4096);
        assert_eq!(c.draft_vocab_size, 32000);
        assert_eq!(c.target_vocab_size, 128256);
        assert_eq!(c.head_dim(), 128);
    }

    // Re-derives a reachability mask from a hand-written `d2t` table. Entries are treated
    // as absolute target ids, and negative ones are skipped.
    // NOTE(review): this replicates a mask-building loop inline instead of calling the
    // loader, so it does not exercise the loader's actual `d2t` interpretation. Confirm it
    // matches `Eagle3DraftCandle::from_var_builder`.
    #[test]
    fn t2d_mask_derives_from_d2t() {
        let target_vocab_size = 10;
        // The trailing -1 must be ignored.
        let d2t_host: Vec<i64> = vec![0, 3, 5, 9, -1];
        let mut target_to_draft = std::collections::HashMap::new();
        let mut t2d_mask = vec![false; target_vocab_size];
        for (di, &ti) in d2t_host.iter().enumerate() {
            if ti >= 0 && (ti as u32) < target_vocab_size as u32 {
                target_to_draft.insert(ti as u32, di as u32);
                t2d_mask[ti as usize] = true;
            }
        }
        assert_eq!(t2d_mask, vec![true, false, false, true, false, true, false, false, false, true]);
    }

    // A 32-layer target picks layers 1, 15 and 28 as low/mid/high features.
    #[test]
    fn default_layer_indices_for_8b() {
        let l = Eagle3RunConfig::default_layers_for(32);
        assert_eq!(l, [1, 15, 28]);
    }
}