use crate::error::{AttentionError, AttentionResult};
/// Token identifier exchanged with the draft, target and Medusa-head interfaces in this module.
pub type TokenId = u32;
/// Tuning parameters for speculative decoding.
///
/// Build one with [`SpeculativeConfig::new`] and check it with
/// [`SpeculativeConfig::validate`] before use.
#[derive(Clone, Debug)]
pub struct SpeculativeConfig {
    /// Number of tokens requested from the draft model per step; validated to lie in `1..=32`.
    pub gamma: usize,
    /// Temperature value; validated to be strictly positive and not NaN.
    pub temperature: f32,
    /// Top-p value; validated to lie in `(0, 1]` and not be NaN.
    pub top_p: f32,
    /// Maximum sequence length; validated to be non-zero.
    pub max_seq_len: usize,
}

impl SpeculativeConfig {
    /// Creates a config with the given `gamma` and the defaults
    /// `temperature = 1.0`, `top_p = 1.0`, `max_seq_len = 2048`.
    pub fn new(gamma: usize) -> Self {
        Self {
            gamma,
            temperature: 1.0,
            top_p: 1.0,
            max_seq_len: 2048,
        }
    }

    /// Checks every field against its allowed range.
    ///
    /// # Errors
    ///
    /// Returns `AttentionError::InvalidConfig` whose message names the first
    /// violated constraint, checked in field order.
    pub fn validate(&self) -> AttentionResult<()> {
        let err = |msg: &str| Err(AttentionError::InvalidConfig(msg.into()));
        if self.gamma == 0 {
            return err("gamma must be > 0");
        }
        if self.gamma > 32 {
            return err("gamma must be <= 32");
        }
        // Every ordered comparison with NaN is false, so `x <= 0.0` alone would
        // let a NaN temperature / top_p through as "valid"; reject it explicitly.
        if self.temperature.is_nan() || self.temperature <= 0.0 {
            return err("temperature must be > 0");
        }
        if self.top_p.is_nan() || self.top_p <= 0.0 || self.top_p > 1.0 {
            return err("top_p must be in (0, 1]");
        }
        if self.max_seq_len == 0 {
            return err("max_seq_len must be > 0");
        }
        Ok(())
    }
}
/// Model that proposes draft tokens for `SpeculativeDecoder::decode_step` to verify.
pub trait DraftModel: Send + Sync {
    /// Proposes tokens that follow `prefix`; `decode_step` requests `gamma` of them.
    ///
    /// Each entry is `(token, probability)`, where `decode_step` treats the
    /// probability as the draft model's probability `q` for that token.
    fn draft_tokens(&self, prefix: &[TokenId], gamma: usize) -> Vec<(TokenId, f32)>;
}
/// Model whose per-position distributions are used to verify draft tokens.
pub trait TargetModel: Send + Sync {
    /// Scores `draft_tokens` appended to `prefix`.
    ///
    /// Entry `i` of the result is read as the distribution, as `(token, probability)`
    /// pairs, for draft position `i`. `decode_step` requires at least
    /// `draft_tokens.len() + 1` entries; the extra one supplies the bonus token.
    fn verify_batch(
        &self,
        prefix: &[TokenId],
        draft_tokens: &[TokenId],
    ) -> Vec<Vec<(TokenId, f32)>>;
}
/// Outcome of one `SpeculativeDecoder::decode_step` call.
#[derive(Debug, Clone)]
pub struct AcceptedTokens {
    /// Emitted tokens: the accepted draft tokens, possibly followed by a
    /// correction or bonus token taken from the target distributions.
    pub tokens: Vec<TokenId>,
    /// Fraction of the drafted tokens that were accepted.
    pub acceptance_rate: f32,
    /// Number of draft-model invocations made during the step.
    pub draft_calls: usize,
    /// Number of target-model invocations made during the step.
    pub target_calls: usize,
}
/// Aggregate statistics for a decoding run; `Default` yields all zeros.
#[derive(Debug, Default, Clone)]
pub struct DecodingStats {
    /// Total number of tokens produced.
    pub tokens_generated: usize,
    /// Observed fraction of draft tokens accepted.
    pub acceptance_rate: f32,
    /// Observed speedup ratio.
    pub speedup_ratio: f32,
    /// Time spent in the draft model, in milliseconds (per the field name).
    pub draft_latency_ms: f64,
    /// Time spent in the target model, in milliseconds (per the field name).
    pub target_latency_ms: f64,
}
/// Heuristic speedup estimate for speculative decoding.
///
/// Returns `g * a / (1 + g * (1 - a))`, where `g = gamma` and `a` is
/// `acceptance_rate` clamped to `[0, 1]`. A NaN `acceptance_rate` propagates
/// to a NaN result.
///
/// NOTE(review): this is a simplified heuristic. It is not the expected
/// tokens-per-step formula `(1 - a^(g+1)) / (1 - a)` from Leviathan et al.;
/// confirm which model callers expect before relying on absolute values.
pub fn theoretical_speedup(gamma: usize, acceptance_rate: f32) -> f32 {
    let draft_len = gamma as f32;
    let alpha = acceptance_rate.clamp(0.0, 1.0);
    let expected_gain = draft_len * alpha;
    let cost = 1.0 + draft_len * (1.0 - alpha);
    // Written as `<=` so a NaN cost falls through to the division (yielding NaN).
    if cost <= 0.0 {
        0.0
    } else {
        expected_gain / cost
    }
}
/// Namespace for the two-model (draft + target) speculative decoding step.
pub struct SpeculativeDecoder;

impl SpeculativeDecoder {
    /// Runs one draft-then-verify step of speculative decoding.
    ///
    /// The draft model is asked for `config.gamma` proposals `(token, q)`, and the
    /// target model is asked for distributions over the same positions. At
    /// position `i`, let `p` be the target's probability of the drafted token
    /// (0 if it is not listed). Let `u = rng_values[i]`, or 0.0 when
    /// `rng_values` is `None` or too short. The token is kept when `p >= q` or
    /// `u < p / q`. The first rejected token is replaced by `sample_adjusted`,
    /// and the step ends there. If every proposal is kept, the first entry of
    /// the extra target distribution (if it has one) is appended as a bonus token.
    ///
    /// Because `u` defaults to 0.0, passing `rng_values = None` keeps every
    /// drafted token with positive target probability and a finite `q`.
    ///
    /// # Errors
    ///
    /// - any error returned by `config.validate()`;
    /// - `AttentionError::EmptyInput` if the draft model proposes nothing;
    /// - `AttentionError::ComputationError` if the target returns fewer than
    ///   `proposals + 1` distributions.
    pub fn decode_step(
        prefix: &[TokenId],
        draft: &dyn DraftModel,
        target: &dyn TargetModel,
        config: &SpeculativeConfig,
        rng_values: Option<&[f32]>,
    ) -> AttentionResult<AcceptedTokens> {
        config.validate()?;

        let proposals = draft.draft_tokens(prefix, config.gamma);
        if proposals.is_empty() {
            return Err(AttentionError::EmptyInput(
                "draft model returned no tokens".into(),
            ));
        }
        let (draft_tokens, draft_probs): (Vec<TokenId>, Vec<f32>) =
            proposals.into_iter().unzip();
        let n_draft = draft_tokens.len();

        let target_dists = target.verify_batch(prefix, &draft_tokens);
        // One distribution per drafted position plus one for the bonus token.
        if target_dists.len() <= n_draft {
            return Err(AttentionError::ComputationError(
                "target model must return gamma+1 distributions".into(),
            ));
        }

        let mut tokens = Vec::new();
        let mut correction = None;
        for (pos, (&token, &q)) in draft_tokens.iter().zip(&draft_probs).enumerate() {
            let p = prob_of_token(&target_dists[pos], token);
            let u = rng_values
                .and_then(|vals| vals.get(pos).copied())
                .unwrap_or(0.0);
            let keep = p >= q || u < p / q;
            if !keep {
                correction = Some(sample_adjusted(
                    &target_dists[pos],
                    &draft_tokens,
                    &draft_probs,
                    pos,
                ));
                break;
            }
            tokens.push(token);
        }

        // Everything pushed so far came from the draft model.
        let kept_from_draft = tokens.len();
        match correction {
            Some(replacement) => tokens.push(replacement),
            None => {
                if let Some(&(bonus, _)) = target_dists[n_draft].first() {
                    tokens.push(bonus);
                }
            }
        }

        let acceptance_rate = if n_draft > 0 {
            kept_from_draft as f32 / n_draft as f32
        } else {
            0.0
        };

        Ok(AcceptedTokens {
            tokens,
            acceptance_rate,
            draft_calls: 1,
            target_calls: 1,
        })
    }
}
/// Probability listed for `token` in `dist`, or 0.0 if the token is absent.
///
/// If the token appears more than once, the first listed probability wins.
fn prob_of_token(dist: &[(TokenId, f32)], token: TokenId) -> f32 {
    for &(candidate, prob) in dist {
        if candidate == token {
            return prob;
        }
    }
    0.0
}
/// Picks the replacement token after the draft token at `position` is rejected.
///
/// Scores each entry of `target_dist` by its residual `max(p_target - q, 0)`,
/// where `q` is the draft probability for the rejected token and 0 for every
/// other token, and returns the highest-scoring token. The residual is maximised
/// greedily rather than sampled, and ties keep the earliest entry. An empty
/// `target_dist` returns token 0.
fn sample_adjusted(
    target_dist: &[(TokenId, f32)],
    draft_tokens: &[TokenId],
    draft_probs: &[f32],
    position: usize,
) -> TokenId {
    let residual = |token: TokenId, p_target: f32| -> f32 {
        let q = if token == draft_tokens[position] {
            draft_probs[position]
        } else {
            0.0
        };
        (p_target - q).max(0.0)
    };
    let fallback = match target_dist.first() {
        Some(&(token, _)) => token,
        None => 0,
    };
    target_dist
        .iter()
        .fold(
            (fallback, f32::NEG_INFINITY),
            |(best_token, best_score), &(token, p_target)| {
                let score = residual(token, p_target);
                if score > best_score {
                    (token, score)
                } else {
                    (best_token, best_score)
                }
            },
        )
        .0
}
/// A prediction head used by `medusa_decode`.
pub trait MedusaHead: Send + Sync {
    /// Returns `(token, probability)` candidates for this head given `prefix`.
    /// `medusa_decode` uses only the first entry.
    fn predict(&self, prefix: &[TokenId]) -> Vec<(TokenId, f32)>;
}
/// Outcome of one `medusa_decode` call.
#[derive(Debug, Clone)]
pub struct MedusaResult {
    /// Tokens accepted from the candidate path (or the target's fallback token).
    pub tokens: Vec<TokenId>,
    /// Number of candidate paths verified; `medusa_decode` always reports 1.
    pub paths_evaluated: usize,
}
/// Verifies a single Medusa-style candidate path against the target model.
///
/// The first prediction of head `k` becomes the candidate token at position
/// `k`. The path ends at the first head that predicts nothing, because later
/// heads' guesses belong to later positions and must not slide forward. The
/// target scores the path, and the longest prefix whose tokens all have
/// positive target probability is accepted. If nothing is accepted, the first
/// token listed in the target's first distribution (if any) is emitted instead.
/// No bonus token is appended after a fully accepted path.
///
/// Only `config.validate()` is consulted; `config.gamma` does not limit the
/// number of heads.
///
/// # Errors
///
/// - any error returned by `config.validate()`;
/// - `AttentionError::EmptyInput` if `heads` is empty or the first head
///   predicts nothing.
pub fn medusa_decode(
    prefix: &[TokenId],
    heads: &[&dyn MedusaHead],
    target: &dyn TargetModel,
    config: &SpeculativeConfig,
) -> AttentionResult<MedusaResult> {
    config.validate()?;
    if heads.is_empty() {
        return Err(AttentionError::EmptyInput(
            "at least one Medusa head required".into(),
        ));
    }

    // `map_while` (not `filter_map`): skipping an empty head would shift every
    // later head's token into the wrong position and verify it against the
    // wrong target distribution.
    let candidate_path: Vec<TokenId> = heads
        .iter()
        .map_while(|head| head.predict(prefix).first().map(|&(token, _)| token))
        .collect();
    if candidate_path.is_empty() {
        return Err(AttentionError::EmptyInput(
            "heads produced no predictions".into(),
        ));
    }

    let target_dists = target.verify_batch(prefix, &candidate_path);
    // `zip` stops at the shorter side, so missing target distributions end the
    // accepted prefix just like a zero-probability token does.
    let mut accepted: Vec<TokenId> = candidate_path
        .iter()
        .zip(&target_dists)
        .take_while(|&(&token, dist)| prob_of_token(dist, token) > 0.0)
        .map(|(&token, _)| token)
        .collect();

    if accepted.is_empty() {
        if let Some(&(token, _)) = target_dists.first().and_then(|dist| dist.first()) {
            accepted.push(token);
        }
    }

    Ok(MedusaResult {
        tokens: accepted,
        paths_evaluated: 1,
    })
}
/// Deterministic [`DraftModel`] stub that cycles through a fixed token list.
pub struct SimpleDraftModel {
    /// Tokens proposed in order, wrapping around when `gamma` exceeds the list length.
    pub tokens: Vec<TokenId>,
    /// Probability reported for every proposed token.
    pub probability: f32,
}

impl DraftModel for SimpleDraftModel {
    /// Returns `gamma` proposals by cycling through `self.tokens`, each paired
    /// with `self.probability`. The prefix is ignored. An empty token list
    /// yields no proposals.
    fn draft_tokens(&self, _prefix: &[TokenId], gamma: usize) -> Vec<(TokenId, f32)> {
        // `cycle` over an empty iterator is itself empty, so an empty token list
        // produces no proposals instead of a remainder-by-zero panic.
        self.tokens
            .iter()
            .cycle()
            .take(gamma)
            .map(|&token| (token, self.probability))
            .collect()
    }
}
/// Deterministic [`TargetModel`] stub that replays fixed distributions.
pub struct SimpleTargetModel {
    /// Distribution returned for each position; later positions reuse the last entry.
    pub distributions: Vec<Vec<(TokenId, f32)>>,
}

impl TargetModel for SimpleTargetModel {
    /// Returns `draft_tokens.len() + 1` distributions and ignores the prefix.
    ///
    /// Entry `i` is `self.distributions[i]`. Positions beyond the stored list
    /// repeat its last distribution, or `[(0, 1.0)]` when the list is empty.
    fn verify_batch(
        &self,
        _prefix: &[TokenId],
        draft_tokens: &[TokenId],
    ) -> Vec<Vec<(TokenId, f32)>> {
        let padding = || match self.distributions.last() {
            Some(last) => last.clone(),
            None => vec![(0, 1.0)],
        };
        (0..=draft_tokens.len())
            .map(|pos| match self.distributions.get(pos) {
                Some(dist) => dist.clone(),
                None => padding(),
            })
            .collect()
    }
}
/// [`MedusaHead`] stub that always predicts the same token.
pub struct SimpleMedusaHead {
    /// Token returned by every prediction.
    pub token: TokenId,
    /// Probability paired with `token`.
    pub probability: f32,
}

impl MedusaHead for SimpleMedusaHead {
    /// Returns the single pair `(self.token, self.probability)`, ignoring the prefix.
    fn predict(&self, _prefix: &[TokenId]) -> Vec<(TokenId, f32)> {
        let prediction = (self.token, self.probability);
        std::iter::once(prediction).collect()
    }
}
// Unit tests for config validation, `decode_step`, `theoretical_speedup` and
// `medusa_decode`. They use the deterministic `Simple*` stubs; tests that pass
// explicit `rng_values` fix the uniform draw `u` used at each draft position.
#[cfg(test)]
mod tests {
use super::*;
// Baseline valid config (gamma = 4) that individual tests mutate.
fn default_config() -> SpeculativeConfig {
SpeculativeConfig::new(4)
}
#[test]
fn test_config_valid() {
assert!(default_config().validate().is_ok());
}
#[test]
fn test_config_gamma_zero() {
let mut cfg = default_config();
cfg.gamma = 0;
assert!(cfg.validate().is_err());
}
#[test]
fn test_config_gamma_too_large() {
let mut cfg = default_config();
cfg.gamma = 33;
assert!(cfg.validate().is_err());
}
#[test]
fn test_config_bad_temperature() {
let mut cfg = default_config();
cfg.temperature = 0.0;
assert!(cfg.validate().is_err());
}
// top_p must lie in (0, 1]: both the lower bound and values above 1 are rejected.
#[test]
fn test_config_bad_top_p() {
let mut cfg = default_config();
cfg.top_p = 0.0;
assert!(cfg.validate().is_err());
cfg.top_p = 1.1;
assert!(cfg.validate().is_err());
}
// Every target probability (0.8, 0.7, 0.6, 0.9) is >= the draft's 0.5, so all
// four drafts are accepted and the fifth distribution supplies bonus token 50.
#[test]
fn test_full_acceptance() {
let draft = SimpleDraftModel {
tokens: vec![10, 20, 30, 40],
probability: 0.5,
};
let target = SimpleTargetModel {
distributions: vec![
vec![(10, 0.8)],
vec![(20, 0.7)],
vec![(30, 0.6)],
vec![(40, 0.9)],
// Fifth entry: the bonus-token distribution.
vec![(50, 1.0)], ],
};
let result = SpeculativeDecoder::decode_step(
&[1, 2, 3],
&draft,
&target,
&default_config(),
None,
)
.unwrap();
assert_eq!(result.tokens.len(), 5);
assert_eq!(result.tokens, vec![10, 20, 30, 40, 50]);
assert!((result.acceptance_rate - 1.0).abs() < f32::EPSILON);
}
// The target never lists the drafted tokens (p = 0 < q = 0.9) and u = 1.0, so
// the first draft is rejected, replaced by 99, and the step stops there.
#[test]
fn test_full_rejection() {
let draft = SimpleDraftModel {
tokens: vec![10, 20, 30, 40],
probability: 0.9,
};
let target = SimpleTargetModel {
distributions: vec![
vec![(99, 0.9)],
vec![(99, 0.9)],
vec![(99, 0.9)],
vec![(99, 0.9)],
vec![(99, 1.0)],
],
};
let result = SpeculativeDecoder::decode_step(
&[1],
&draft,
&target,
&default_config(),
// u = 1.0 at every position, so `u < p / q` can never accept.
Some(&[1.0, 1.0, 1.0, 1.0]), )
.unwrap();
assert_eq!(result.tokens.len(), 1);
assert_eq!(result.tokens[0], 99);
assert!((result.acceptance_rate - 0.0).abs() < f32::EPSILON);
}
// Positions 0 and 1 pass p >= q. Token 30 is absent at position 2 (p = 0) with
// u = 1.0, so it is replaced by 77 and the step ends with 2 of 4 accepted.
#[test]
fn test_partial_acceptance() {
let draft = SimpleDraftModel {
tokens: vec![10, 20, 30, 40],
probability: 0.5,
};
let target = SimpleTargetModel {
distributions: vec![
vec![(10, 0.8)],
vec![(20, 0.6)],
// Position 2 omits token 30; position 3 is never reached.
vec![(77, 0.9)], vec![(40, 0.9)],
vec![(50, 1.0)],
],
};
let result = SpeculativeDecoder::decode_step(
&[1],
&draft,
&target,
&default_config(),
// u = 1.0 only at position 2, where the rejection happens.
Some(&[0.0, 0.0, 1.0, 0.0]), )
.unwrap();
assert_eq!(result.tokens.len(), 3);
assert_eq!(result.tokens[0], 10);
assert_eq!(result.tokens[1], 20);
assert_eq!(result.tokens[2], 77);
assert!((result.acceptance_rate - 0.5).abs() < f32::EPSILON);
}
// p(10) = 0.3 < q = 0.8 and u = 1.0 is not below p/q = 0.375, so 10 is rejected.
// The residuals are max(0.3 - 0.8, 0) = 0 for 10 and 0.7 for 42, so 42 is emitted.
#[test]
fn test_rejection_sampling_distribution() {
let draft = SimpleDraftModel {
tokens: vec![10],
probability: 0.8,
};
let target = SimpleTargetModel {
distributions: vec![
vec![(10, 0.3), (42, 0.7)],
vec![(99, 1.0)],
],
};
let cfg = SpeculativeConfig::new(1);
let result = SpeculativeDecoder::decode_step(
&[1],
&draft,
&target,
&cfg,
// Forces rejection at the single draft position.
Some(&[1.0]), )
.unwrap();
assert_eq!(result.tokens.len(), 1);
assert_eq!(result.tokens[0], 42);
}
// Pins the formula g*a / (1 + g*(1 - a)) at several points.
#[test]
fn test_theoretical_speedup() {
let s = theoretical_speedup(4, 1.0);
assert!((s - 4.0).abs() < 1e-5);
let s = theoretical_speedup(4, 0.0);
assert!(s.abs() < 1e-5);
let s = theoretical_speedup(4, 0.8);
assert!((s - 3.2 / 1.8).abs() < 1e-4);
let s = theoretical_speedup(8, 0.9);
assert!((s - 7.2 / 1.8).abs() < 1e-4);
}
// Both head predictions have positive target probability, so both are accepted;
// medusa_decode appends no bonus token, so 99 does not appear.
#[test]
fn test_medusa_decode() {
let h1 = SimpleMedusaHead {
token: 10,
probability: 0.9,
};
let h2 = SimpleMedusaHead {
token: 20,
probability: 0.8,
};
let target = SimpleTargetModel {
distributions: vec![
vec![(10, 0.7)],
vec![(20, 0.6)],
vec![(99, 1.0)],
],
};
let heads: Vec<&dyn MedusaHead> = vec![&h1, &h2];
let result =
medusa_decode(&[1, 2], &heads, &target, &default_config()).unwrap();
assert_eq!(result.tokens, vec![10, 20]);
assert_eq!(result.paths_evaluated, 1);
}
// An empty head list must produce an error.
#[test]
fn test_medusa_no_heads() {
let target = SimpleTargetModel {
distributions: vec![vec![(1, 1.0)]],
};
let heads: Vec<&dyn MedusaHead> = vec![];
let result =
medusa_decode(&[1], &heads, &target, &default_config());
assert!(result.is_err());
}
// p = 0.4 < q = 0.8, but u = 0.3 < p/q = 0.5, so 10 is accepted
// probabilistically and the bonus token 99 follows.
#[test]
fn test_probabilistic_acceptance() {
let draft = SimpleDraftModel {
tokens: vec![10],
probability: 0.8,
};
let target = SimpleTargetModel {
distributions: vec![
// Position 0 (p(10) = 0.4), then the bonus distribution.
vec![(10, 0.4)], vec![(99, 1.0)],
],
};
let cfg = SpeculativeConfig::new(1);
let result = SpeculativeDecoder::decode_step(
&[1],
&draft,
&target,
&cfg,
Some(&[0.3]),
)
.unwrap();
assert_eq!(result.tokens, vec![10, 99]);
assert!((result.acceptance_rate - 1.0).abs() < f32::EPSILON);
}
// An empty prefix is valid input; p = 0.9 >= q = 0.5, so 5 is accepted and
// the bonus token 6 follows.
#[test]
fn test_empty_prefix() {
let draft = SimpleDraftModel {
tokens: vec![5],
probability: 0.5,
};
let target = SimpleTargetModel {
distributions: vec![
vec![(5, 0.9)],
vec![(6, 1.0)],
],
};
let cfg = SpeculativeConfig::new(1);
let result = SpeculativeDecoder::decode_step(
&[],
&draft,
&target,
&cfg,
None,
)
.unwrap();
assert_eq!(result.tokens, vec![5, 6]);
}
}