use super::{TransformerError, TransformerResult};
/// Strategy for deciding whether a drafted token is accepted
/// against the target model's distribution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AcceptanceMethod {
    /// Deterministic rejection-sampling-style check (the default).
    #[default]
    RejectionSampling,
    /// Accepts tokens whose draft and target surprisal are close.
    TypicalAcceptance,
    /// Tree-based verification (currently falls back to rejection
    /// sampling inside `verify`).
    TreeVerification,
}
/// Configuration of the small draft model that proposes candidate tokens.
#[derive(Debug, Clone)]
pub struct DraftModelConfig {
    /// Identifier of the draft model.
    pub draft_model_name: String,
    /// Vocabulary size of the draft model; must be > 0
    /// (validated in `SpeculativeDecoder::new`).
    pub draft_vocab_size: usize,
    /// Number of tokens the draft model proposes per step.
    pub num_draft_tokens: usize,
}
/// Top-level configuration for speculative decoding.
#[derive(Debug, Clone)]
pub struct SpeculativeDecoderConfig {
    /// Settings for the draft model.
    pub draft_model: DraftModelConfig,
    /// Number of speculative tokens per step; must be > 0
    /// (validated in `SpeculativeDecoder::new`).
    pub num_speculative_tokens: usize,
    /// How drafted tokens are accepted or rejected.
    pub acceptance_method: AcceptanceMethod,
    /// Maximum branching factor when building candidate trees.
    pub max_tree_width: usize,
    /// Maximum depth of candidate trees.
    pub max_tree_depth: usize,
}
impl Default for SpeculativeDecoderConfig {
    /// Baseline settings: five speculative tokens checked via rejection
    /// sampling, with a 3-wide / 5-deep tree budget and a 32k draft vocab.
    fn default() -> Self {
        let draft_model = DraftModelConfig {
            draft_model_name: String::new(),
            draft_vocab_size: 32000,
            num_draft_tokens: 5,
        };
        Self {
            draft_model,
            num_speculative_tokens: 5,
            acceptance_method: AcceptanceMethod::RejectionSampling,
            max_tree_width: 3,
            max_tree_depth: 5,
        }
    }
}
/// Result of verifying one batch of drafted tokens.
#[derive(Debug, Clone)]
pub struct SpeculativeOutput {
    /// The accepted prefix of the drafted tokens.
    pub accepted_tokens: Vec<u32>,
    /// Length of `accepted_tokens`.
    pub num_accepted: usize,
    /// Number of tokens that were drafted this step.
    pub num_drafted: usize,
    /// Token sampled from the target distribution at the first
    /// unverified position, when such a position exists.
    pub correction_token: Option<u32>,
    /// Whether at least one draft token was accepted.
    pub beneficial: bool,
}
impl SpeculativeOutput {
    /// Fraction of drafted tokens that were accepted; 0.0 when nothing
    /// was drafted (avoids dividing by zero).
    pub fn acceptance_rate(&self) -> f64 {
        match self.num_drafted {
            0 => 0.0,
            drafted => self.num_accepted as f64 / drafted as f64,
        }
    }

    /// Total tokens produced this step: the accepted drafts plus the
    /// optional correction token.
    pub fn total_generated(&self) -> usize {
        self.num_accepted + usize::from(self.correction_token.is_some())
    }
}
/// Verifies draft-model proposals against target-model distributions
/// and tracks acceptance statistics across steps.
#[derive(Debug)]
pub struct SpeculativeDecoder {
    // Validated at construction time in `new`.
    config: SpeculativeDecoderConfig,
    // Cumulative counters updated by each `verify` call.
    stats: SpeculativeStats,
}
/// Running counters accumulated over all verification steps.
#[derive(Debug, Clone, Default)]
pub struct SpeculativeStats {
    /// Number of `verify` calls completed.
    pub total_steps: u64,
    /// Total tokens drafted across all steps.
    pub total_drafted: u64,
    /// Total drafted tokens that were accepted.
    pub total_accepted: u64,
    /// Steps that produced a correction token.
    pub total_corrections: u64,
}
impl SpeculativeStats {
    /// Lifetime acceptance rate over all drafted tokens; 0.0 before
    /// anything has been drafted.
    pub fn acceptance_rate(&self) -> f64 {
        if self.total_drafted == 0 {
            0.0
        } else {
            self.total_accepted as f64 / self.total_drafted as f64
        }
    }

    /// Average tokens emitted (accepted + corrections) per verification
    /// step; 0.0 before any step has run.
    pub fn avg_tokens_per_step(&self) -> f64 {
        if self.total_steps == 0 {
            0.0
        } else {
            (self.total_accepted + self.total_corrections) as f64 / self.total_steps as f64
        }
    }

    /// Rough speedup estimate relative to one-token-per-step decoding;
    /// 1.0 (no speedup) before any step has run.
    pub fn estimated_speedup(&self) -> f64 {
        if self.total_steps == 0 {
            1.0
        } else {
            self.avg_tokens_per_step()
        }
    }
}
impl SpeculativeDecoder {
    /// Builds a decoder after validating the configuration.
    ///
    /// # Errors
    /// Returns a `SpeculativeError` when `num_speculative_tokens` or
    /// `draft_vocab_size` is zero, or when tree verification is selected
    /// with a zero tree width or depth.
    pub fn new(config: SpeculativeDecoderConfig) -> TransformerResult<Self> {
        if config.num_speculative_tokens == 0 {
            return Err(TransformerError::SpeculativeError(
                "num_speculative_tokens must be > 0".to_string(),
            ));
        }
        if config.draft_model.draft_vocab_size == 0 {
            return Err(TransformerError::SpeculativeError(
                "draft_vocab_size must be > 0".to_string(),
            ));
        }
        if config.acceptance_method == AcceptanceMethod::TreeVerification
            && (config.max_tree_width == 0 || config.max_tree_depth == 0)
        {
            return Err(TransformerError::SpeculativeError(
                "tree dimensions must be > 0".to_string(),
            ));
        }
        Ok(Self {
            config,
            stats: SpeculativeStats::default(),
        })
    }

    /// Verifies `draft_tokens` against the target model's per-position
    /// distributions, returning the accepted prefix plus an optional
    /// correction token sampled at the first unverified position.
    ///
    /// `draft_probs[i]` is the draft model's probability for
    /// `draft_tokens[i]`; `target_probs[i]` is the target model's full
    /// distribution at position `i` (an extra trailing position, if
    /// present, is used for the correction token).
    ///
    /// # Errors
    /// Returns a `SpeculativeError` when `draft_tokens` and `draft_probs`
    /// differ in length, or when `target_probs` covers fewer positions
    /// than `draft_tokens`.
    pub fn verify(
        &mut self,
        draft_tokens: &[u32],
        draft_probs: &[f64],
        target_probs: &[Vec<f64>],
    ) -> TransformerResult<SpeculativeOutput> {
        if draft_tokens.len() != draft_probs.len() {
            return Err(TransformerError::SpeculativeError(
                "draft_tokens and draft_probs must have same length".to_string(),
            ));
        }
        if target_probs.len() < draft_tokens.len() {
            return Err(TransformerError::SpeculativeError(
                "target_probs must have at least as many positions as draft_tokens".to_string(),
            ));
        }
        match self.config.acceptance_method {
            AcceptanceMethod::RejectionSampling => {
                self.verify_rejection(draft_tokens, draft_probs, target_probs)
            }
            AcceptanceMethod::TypicalAcceptance => {
                self.verify_typical(draft_tokens, draft_probs, target_probs)
            }
            // Verifying a single linear path of a tree reduces to
            // rejection sampling; full tree verification is not wired in.
            AcceptanceMethod::TreeVerification => {
                self.verify_rejection(draft_tokens, draft_probs, target_probs)
            }
        }
    }

    /// Builds candidate token paths rooted at the top-`max_tree_width`
    /// tokens of `root_probs`, up to `depth` levels deep (capped at
    /// `max_tree_depth`).
    ///
    /// # Errors
    /// Returns a `SpeculativeError` when `root_probs` is empty.
    pub fn build_tree(&self, root_probs: &[f64], depth: usize) -> TransformerResult<Vec<Vec<u32>>> {
        if root_probs.is_empty() {
            return Err(TransformerError::SpeculativeError(
                "empty probability distribution".to_string(),
            ));
        }
        let effective_depth = depth.min(self.config.max_tree_depth);
        let effective_width = self.config.max_tree_width;
        let top_k = self.top_k_indices(root_probs, effective_width);
        let mut paths = Vec::with_capacity(top_k.len());
        for &token_id in &top_k {
            let mut path = vec![token_id];
            // saturating_sub: a requested depth of 0 must not underflow
            // (plain `effective_depth - 1` panics in debug builds).
            self.extend_path(&mut path, effective_depth.saturating_sub(1), effective_width);
            paths.push(path);
        }
        Ok(paths)
    }

    /// Cumulative verification statistics.
    pub fn stats(&self) -> &SpeculativeStats {
        &self.stats
    }

    /// Clears all accumulated statistics.
    pub fn reset_stats(&mut self) {
        self.stats = SpeculativeStats::default();
    }

    /// The decoder's configuration.
    pub fn config(&self) -> &SpeculativeDecoderConfig {
        &self.config
    }

    /// Number of speculative tokens per step, from the configuration.
    pub fn num_speculative_tokens(&self) -> usize {
        self.config.num_speculative_tokens
    }

    /// Deterministic analogue of speculative rejection sampling: a token
    /// is accepted while the target probability is at least the draft
    /// probability, or at least half of it; the first rejection stops
    /// acceptance of all later tokens.
    fn verify_rejection(
        &mut self,
        draft_tokens: &[u32],
        draft_probs: &[f64],
        target_probs: &[Vec<f64>],
    ) -> TransformerResult<SpeculativeOutput> {
        let mut accepted_tokens = Vec::with_capacity(draft_tokens.len());
        for (i, (&token, &draft_p)) in draft_tokens.iter().zip(draft_probs.iter()).enumerate() {
            let target_p = Self::target_prob(target_probs, i, token);
            // Short-circuit keeps the division guarded by draft_p > 0.
            let accept = draft_p <= 0.0 || target_p >= draft_p || target_p / draft_p >= 0.5;
            if !accept {
                break;
            }
            accepted_tokens.push(token);
        }
        Ok(self.finalize_step(accepted_tokens, draft_tokens.len(), target_probs))
    }

    /// Typical acceptance: a token is accepted while its surprisal under
    /// the draft and target models differs by less than 2 nats and the
    /// target probability is non-negligible.
    fn verify_typical(
        &mut self,
        draft_tokens: &[u32],
        draft_probs: &[f64],
        target_probs: &[Vec<f64>],
    ) -> TransformerResult<SpeculativeOutput> {
        let mut accepted_tokens = Vec::with_capacity(draft_tokens.len());
        for (i, (&token, &draft_p)) in draft_tokens.iter().zip(draft_probs.iter()).enumerate() {
            let target_p = Self::target_prob(target_probs, i, token);
            let surprisal_diff =
                (self.token_surprisal(draft_p) - self.token_surprisal(target_p)).abs();
            if surprisal_diff < 2.0 && target_p > 1e-8 {
                accepted_tokens.push(token);
            } else {
                break;
            }
        }
        Ok(self.finalize_step(accepted_tokens, draft_tokens.len(), target_probs))
    }

    /// Target-model probability of `token` at position `pos`, or 0.0 when
    /// the position or token index is out of range.
    fn target_prob(target_probs: &[Vec<f64>], pos: usize, token: u32) -> f64 {
        target_probs
            .get(pos)
            .and_then(|probs| probs.get(token as usize).copied())
            .unwrap_or(0.0)
    }

    /// Samples the correction token (argmax at the first unverified
    /// position, if present), updates the running statistics, and builds
    /// the step output. Shared by all verification methods.
    fn finalize_step(
        &mut self,
        accepted_tokens: Vec<u32>,
        num_drafted: usize,
        target_probs: &[Vec<f64>],
    ) -> SpeculativeOutput {
        let num_accepted = accepted_tokens.len();
        let correction_token = target_probs
            .get(num_accepted)
            .map(|probs| self.sample_argmax(probs));
        self.stats.total_steps += 1;
        self.stats.total_drafted += num_drafted as u64;
        self.stats.total_accepted += num_accepted as u64;
        if correction_token.is_some() {
            self.stats.total_corrections += 1;
        }
        SpeculativeOutput {
            accepted_tokens,
            num_accepted,
            num_drafted,
            correction_token,
            beneficial: num_accepted > 0,
        }
    }

    /// Surprisal (-ln p) of a single token probability; infinite for
    /// non-positive probabilities.
    fn token_surprisal(&self, prob: f64) -> f64 {
        if prob <= 0.0 {
            return f64::INFINITY;
        }
        -prob.ln()
    }

    /// Index of the highest-probability entry (first one wins on ties);
    /// 0 for an empty distribution. NaN entries are skipped because
    /// `NaN > x` is always false.
    fn sample_argmax(&self, probs: &[f64]) -> u32 {
        let mut max_idx = 0usize;
        let mut max_val = f64::NEG_INFINITY;
        for (i, &p) in probs.iter().enumerate() {
            if p > max_val {
                max_val = p;
                max_idx = i;
            }
        }
        max_idx as u32
    }

    /// Indices of the `k` largest probabilities, in descending order.
    /// Stable sort preserves index order among equal probabilities.
    fn top_k_indices(&self, probs: &[f64], k: usize) -> Vec<u32> {
        let mut indexed: Vec<(usize, f64)> = probs.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        indexed.iter().take(k).map(|(i, _)| *i as u32).collect()
    }

    /// Placeholder: extending a path requires querying the draft model,
    /// which is not wired in yet, so paths stay length 1.
    fn extend_path(&self, _path: &mut [u32], _remaining_depth: usize, _width: usize) {}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared baseline configuration for the tests below.
    fn default_config() -> SpeculativeDecoderConfig {
        SpeculativeDecoderConfig {
            draft_model: DraftModelConfig {
                draft_model_name: "test-draft".to_string(),
                draft_vocab_size: 100,
                num_draft_tokens: 5,
            },
            num_speculative_tokens: 5,
            acceptance_method: AcceptanceMethod::RejectionSampling,
            max_tree_width: 3,
            max_tree_depth: 5,
        }
    }

    #[test]
    fn test_create_decoder() {
        let decoder = SpeculativeDecoder::new(default_config()).unwrap();
        assert_eq!(decoder.num_speculative_tokens(), 5);
    }

    #[test]
    fn test_invalid_config() {
        // Zero speculative tokens is rejected.
        let mut cfg = default_config();
        cfg.num_speculative_tokens = 0;
        assert!(SpeculativeDecoder::new(cfg).is_err());

        // Zero draft vocabulary is rejected.
        let mut cfg = default_config();
        cfg.draft_model.draft_vocab_size = 0;
        assert!(SpeculativeDecoder::new(cfg).is_err());
    }

    #[test]
    fn test_verify_all_accept() {
        let mut decoder = SpeculativeDecoder::new(default_config()).unwrap();
        let draft_tokens = vec![1, 2, 3];
        let draft_probs = vec![0.3, 0.4, 0.5];
        // Target assigns every drafted token more mass than the draft did,
        // plus one extra position for the correction token.
        let target_probs = vec![
            vec![0.0, 0.8, 0.1, 0.1],
            vec![0.0, 0.0, 0.9, 0.1],
            vec![0.0, 0.0, 0.0, 0.7],
            vec![0.5, 0.3, 0.1, 0.1],
        ];
        let output = decoder
            .verify(&draft_tokens, &draft_probs, &target_probs)
            .unwrap();
        assert_eq!(output.num_accepted, 3);
        assert!(output.beneficial);
        assert!(output.correction_token.is_some());
    }

    #[test]
    fn test_verify_partial_accept() {
        let mut decoder = SpeculativeDecoder::new(default_config()).unwrap();
        let draft_tokens = vec![1, 2, 3];
        let draft_probs = vec![0.3, 0.8, 0.5];
        // The second position has a very low target/draft ratio, so
        // acceptance stops before the end.
        let target_probs = vec![
            vec![0.0, 0.5, 0.3, 0.2],
            vec![0.0, 0.0, 0.01, 0.9],
            vec![0.0, 0.0, 0.0, 0.9],
        ];
        let output = decoder
            .verify(&draft_tokens, &draft_probs, &target_probs)
            .unwrap();
        assert!(output.num_accepted >= 1);
        assert!(output.num_accepted < 3);
    }

    #[test]
    fn test_verify_none_accept() {
        let mut decoder = SpeculativeDecoder::new(default_config()).unwrap();
        let draft_tokens = vec![1];
        let draft_probs = vec![0.9];
        // Target gives the drafted token almost no probability.
        let target_probs = vec![vec![0.0, 0.001], vec![0.5, 0.5]];
        let output = decoder
            .verify(&draft_tokens, &draft_probs, &target_probs)
            .unwrap();
        assert_eq!(output.num_accepted, 0);
    }

    #[test]
    fn test_verify_typical() {
        let cfg = SpeculativeDecoderConfig {
            acceptance_method: AcceptanceMethod::TypicalAcceptance,
            ..default_config()
        };
        let mut decoder = SpeculativeDecoder::new(cfg).unwrap();
        let draft_tokens = vec![1, 2];
        let draft_probs = vec![0.3, 0.3];
        let target_probs = vec![
            vec![0.0, 0.35, 0.3, 0.35],
            vec![0.0, 0.1, 0.4, 0.5],
            vec![0.5, 0.3, 0.1, 0.1],
        ];
        let output = decoder
            .verify(&draft_tokens, &draft_probs, &target_probs)
            .unwrap();
        assert_eq!(output.num_drafted, 2);
    }

    #[test]
    fn test_speculative_output_acceptance_rate() {
        let output = SpeculativeOutput {
            accepted_tokens: vec![1, 2, 3],
            num_accepted: 3,
            num_drafted: 5,
            correction_token: Some(4),
            beneficial: true,
        };
        assert!((output.acceptance_rate() - 0.6).abs() < 1e-10);
        assert_eq!(output.total_generated(), 4);
    }

    #[test]
    fn test_stats_tracking() {
        let mut decoder = SpeculativeDecoder::new(default_config()).unwrap();
        let draft_tokens = vec![1, 2];
        let draft_probs = vec![0.3, 0.3];
        let target_probs = vec![vec![0.0, 0.8], vec![0.0, 0.0, 0.8], vec![0.5, 0.5]];
        let _ = decoder.verify(&draft_tokens, &draft_probs, &target_probs);
        assert_eq!(decoder.stats().total_steps, 1);
        assert_eq!(decoder.stats().total_drafted, 2);
        assert!(decoder.stats().total_accepted > 0);
    }

    #[test]
    fn test_stats_reset() {
        let mut decoder = SpeculativeDecoder::new(default_config()).unwrap();
        let _ = decoder.verify(&[1], &[0.5], &[vec![0.0, 0.8], vec![0.5, 0.5]]);
        assert!(decoder.stats().total_steps > 0);
        decoder.reset_stats();
        assert_eq!(decoder.stats().total_steps, 0);
    }

    #[test]
    fn test_build_tree() {
        let decoder = SpeculativeDecoder::new(default_config()).unwrap();
        let probs = vec![0.1, 0.4, 0.3, 0.2];
        let paths = decoder.build_tree(&probs, 3).unwrap();
        assert!(!paths.is_empty());
        // At most max_tree_width root branches.
        assert!(paths.len() <= 3);
    }

    #[test]
    fn test_build_tree_empty() {
        let decoder = SpeculativeDecoder::new(default_config()).unwrap();
        assert!(decoder.build_tree(&[], 3).is_err());
    }

    #[test]
    fn test_verify_mismatched_lengths() {
        let mut decoder = SpeculativeDecoder::new(default_config()).unwrap();
        assert!(decoder.verify(&[1, 2], &[0.5], &[]).is_err());
    }

    #[test]
    fn test_estimated_speedup() {
        let stats = SpeculativeStats {
            total_steps: 10,
            total_drafted: 50,
            total_accepted: 30,
            total_corrections: 10,
        };
        // (30 accepted + 10 corrections) / 10 steps = 4 tokens per step.
        let speedup = stats.estimated_speedup();
        assert!((speedup - 4.0).abs() < 1e-10);
    }
}