use crate::context::ContextParams;
use crate::error::MullamaError;
use crate::token::TokenId;
use crate::{Context, Model};
use std::sync::Arc;
/// Token identifier used throughout the speculative decoder; an alias for `TokenId`.
pub type Token = TokenId;
/// Speculative decoder pairing a target model with a draft model.
///
/// Each step, the draft model proposes up to `SpeculativeConfig::lookahead_tokens`
/// candidate tokens; the target model then accepts them in order or replaces the first
/// rejected candidate with a token sampled from its own distribution.
pub struct SpeculativeDecoder {
/// Model whose judgement decides which tokens are kept; also used for end-of-generation checks.
target_model: Arc<Model>,
/// Model used to propose candidate tokens. `validate_models` only warns (it does not fail)
/// when this model is not smaller than the target.
draft_model: Arc<Model>,
/// Inference context created for `target_model`.
target_context: Context,
/// Inference context created for `draft_model`.
draft_context: Context,
/// Decoding parameters; `lookahead_tokens` may be changed at runtime when
/// `dynamic_lookahead` is enabled.
config: SpeculativeConfig,
/// Cumulative statistics; cleared by `reset_stats`.
stats: SpeculativeStats,
}
impl std::fmt::Debug for SpeculativeDecoder {
    /// Renders only the configuration and statistics.
    ///
    /// The models and contexts are deliberately left out, and the output ends with `..`
    /// to signal that fields were omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug = f.debug_struct("SpeculativeDecoder");
        debug.field("config", &self.config);
        debug.field("stats", &self.stats);
        debug.finish_non_exhaustive()
    }
}
/// Tuning parameters for [`SpeculativeDecoder`].
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
/// Maximum number of candidate tokens the draft model proposes per speculative step.
pub lookahead_tokens: usize,
/// Minimum target/draft probability ratio a candidate needs in order to be accepted;
/// checked in addition to the randomized acceptance test.
pub acceptance_threshold: f32,
/// NOTE(review): not read by `SpeculativeDecoder` in this file — confirm whether it is
/// consumed elsewhere or still to be implemented.
pub max_rejections: usize,
/// Temperature that the draft model's logits are divided by before sampling.
pub draft_temperature: f32,
/// Temperature that the target model's logits are divided by before validation.
pub target_temperature: f32,
/// When `true`, `generate` adapts `lookahead_tokens` from the observed acceptance rate.
pub dynamic_lookahead: bool,
/// NOTE(review): not read by `SpeculativeDecoder` in this file — confirm intended use.
pub batch_size: usize,
}
/// Cumulative counters and timings collected by [`SpeculativeDecoder`].
#[derive(Debug, Clone, Default)]
pub struct SpeculativeStats {
/// Total number of tokens emitted across all speculative steps.
pub total_tokens: usize,
/// Number of tokens the decoder counted as accepted; `acceptance_rate` divides this
/// by `total_tokens`.
pub accepted_tokens: usize,
/// Number of drafted candidates the decoder counted as rejected.
pub rejected_tokens: usize,
/// Running average of the lookahead, as maintained by the decoder.
pub avg_lookahead: f32,
/// Cumulative wall-clock nanoseconds spent producing draft proposals.
pub draft_time_ns: u64,
/// Cumulative wall-clock nanoseconds spent validating proposals with the target model.
pub target_time_ns: u64,
/// Number of speculative steps performed.
pub speculation_rounds: usize,
}
/// Outcome of a single `SpeculativeDecoder::speculative_step`.
#[derive(Debug)]
pub struct SpeculativeResult {
/// Tokens emitted by this step, in generation order.
pub tokens: Vec<Token>,
/// `false` once an end-of-generation token has been emitted.
pub should_continue: bool,
/// Snapshot of the decoder's cumulative statistics after this step.
pub stats: SpeculativeStats,
}
/// A token proposed by the draft model together with the draft model's score for it.
#[derive(Debug, Clone)]
pub struct CandidateToken {
/// Proposed token id.
pub token: Token,
/// Draft-model score for `token`, in log space.
pub log_prob: f32,
/// `log_prob.exp()`; used as the denominator of the acceptance ratio.
pub probability: f32,
}
/// Candidate tokens drafted in one speculative step.
#[derive(Debug)]
pub struct DraftProposal {
/// Draft candidates, in generation order.
pub candidates: Vec<CandidateToken>,
/// Serialized draft-context state captured by the decoder while building this proposal.
pub draft_context_state: Vec<u8>, }
impl SpeculativeDecoder {
    /// Floor applied to sampling temperatures.
    ///
    /// A temperature of `0.0` (or a negative/NaN value) would otherwise divide logits
    /// by zero; clamping turns it into near-greedy sampling instead.
    const MIN_TEMPERATURE: f32 = 1e-4;

    /// Largest lookahead `adjust_lookahead` will grow to.
    const MAX_DYNAMIC_LOOKAHEAD: usize = 8;

    /// Rounds observed before `adjust_lookahead` starts adapting the lookahead.
    const LOOKAHEAD_WARMUP_ROUNDS: usize = 10;

    /// Creates a decoder for `target_model` that uses `draft_model` to propose tokens.
    ///
    /// Each model gets a fresh context built from `ContextParams::default()`.
    ///
    /// # Errors
    ///
    /// Returns `MullamaError::InvalidInput` when the models' vocabularies differ in size
    /// or type, and propagates any error raised while creating either context.
    pub fn new(
        target_model: Arc<Model>,
        draft_model: Arc<Model>,
        config: SpeculativeConfig,
    ) -> Result<Self, MullamaError> {
        Self::validate_models(&target_model, &draft_model)?;
        let target_context = Context::new(Arc::clone(&target_model), ContextParams::default())?;
        let draft_context = Context::new(Arc::clone(&draft_model), ContextParams::default())?;
        Ok(Self {
            target_model,
            draft_model,
            target_context,
            draft_context,
            config,
            stats: SpeculativeStats::default(),
        })
    }

    /// Generates up to `max_tokens` tokens that follow `prompt_tokens`.
    ///
    /// Both contexts are cleared and seeded with the prompt first. Generation stops early
    /// once an end-of-generation token is emitted. A single speculative step can produce
    /// several tokens, so any surplus beyond `max_tokens` is dropped.
    ///
    /// # Errors
    ///
    /// Returns `MullamaError::InvalidInput` for an empty prompt and propagates decoding
    /// and sampling errors.
    pub fn generate(
        &mut self,
        prompt_tokens: &[Token],
        max_tokens: usize,
    ) -> Result<Vec<Token>, MullamaError> {
        self.initialize_contexts(prompt_tokens)?;
        let mut generated_tokens = Vec::new();
        while generated_tokens.len() < max_tokens {
            let result = self.speculative_step()?;
            // A step that yields nothing would repeat forever; treat it as a stop signal.
            let made_progress = !result.tokens.is_empty();
            generated_tokens.extend(result.tokens);
            if !result.should_continue || !made_progress {
                break;
            }
            if self.config.dynamic_lookahead {
                self.adjust_lookahead();
            }
        }
        generated_tokens.truncate(max_tokens);
        Ok(generated_tokens)
    }

    /// Runs one draft-then-validate round and updates the statistics.
    ///
    /// # Errors
    ///
    /// Propagates errors from drafting, validation, or re-synchronizing the draft context.
    pub fn speculative_step(&mut self) -> Result<SpeculativeResult, MullamaError> {
        let draft_start = std::time::Instant::now();
        let proposal = self.generate_draft_proposal()?;
        self.stats.draft_time_ns = self
            .stats
            .draft_time_ns
            .saturating_add(Self::elapsed_ns(draft_start));

        let target_start = std::time::Instant::now();
        let (emitted_tokens, accepted_drafts, should_continue) =
            self.validate_with_target(&proposal)?;
        self.stats.target_time_ns = self
            .stats
            .target_time_ns
            .saturating_add(Self::elapsed_ns(target_start));

        let proposed = proposal.candidates.len();
        self.stats.speculation_rounds += 1;
        self.stats.total_tokens += emitted_tokens.len();
        // Only drafts the target kept count as accepted; a replacement token sampled
        // after a rejection is emitted but is not an accepted draft.
        self.stats.accepted_tokens += accepted_drafts;
        self.stats.rejected_tokens += proposed - accepted_drafts;
        // Running mean of the number of candidates drafted per round.
        self.stats.avg_lookahead += (proposed as f32 - self.stats.avg_lookahead)
            / self.stats.speculation_rounds as f32;

        self.update_contexts(&proposal, &emitted_tokens)?;
        Ok(SpeculativeResult {
            tokens: emitted_tokens,
            should_continue,
            stats: self.stats.clone(),
        })
    }

    /// Samples up to `lookahead_tokens` candidates (at least one) from the draft model,
    /// decoding each into the draft context, and stops early at end-of-generation.
    fn generate_draft_proposal(&mut self) -> Result<DraftProposal, MullamaError> {
        // Checkpoint the draft context at the committed prefix *before* drafting so that
        // `update_contexts` can roll back candidates the target rejects.
        let draft_context_state = self.draft_context.save_state();
        // A lookahead of zero would draft nothing and stall `generate`.
        let lookahead = self.config.lookahead_tokens.max(1);
        let mut candidates = Vec::new();
        for _ in 0..lookahead {
            let logits = self.draft_context.get_logits();
            if logits.is_empty() {
                return Err(MullamaError::GenerationError(
                    "Empty logits from draft model".to_string(),
                ));
            }
            let scaled_logits = self.apply_temperature(logits, self.config.draft_temperature);
            let token = self.sample_from_logits(&scaled_logits)?;
            // A raw logit is not a log-probability: normalize with log-sum-exp.
            let log_prob = scaled_logits[token as usize] - Self::log_sum_exp(&scaled_logits);
            candidates.push(CandidateToken {
                token,
                log_prob,
                probability: log_prob.exp(),
            });
            self.draft_context.decode(&[token])?;
            if self.draft_model.token_is_eog(token) {
                break;
            }
        }
        Ok(DraftProposal {
            candidates,
            draft_context_state,
        })
    }

    /// Checks each draft candidate against the target model, in order.
    ///
    /// A candidate is accepted when a uniform draw falls below `min(1, p_target / p_draft)`
    /// and that ratio also reaches `acceptance_threshold`. On the first rejection a
    /// replacement is sampled from the target distribution with the rejected token
    /// excluded, and validation stops. Every emitted token is decoded into the target
    /// context.
    ///
    /// Returns `(emitted_tokens, accepted_draft_count, should_continue)`, where
    /// `should_continue` is `false` once an end-of-generation token was emitted.
    fn validate_with_target(
        &mut self,
        proposal: &DraftProposal,
    ) -> Result<(Vec<Token>, usize, bool), MullamaError> {
        let mut emitted = Vec::with_capacity(proposal.candidates.len());
        let mut accepted_drafts = 0;
        for candidate in &proposal.candidates {
            let target_logits = self.target_context.get_logits();
            if target_logits.is_empty() {
                return Err(MullamaError::GenerationError(
                    "Empty logits from target model".to_string(),
                ));
            }
            let target_scaled_logits =
                self.apply_temperature(target_logits, self.config.target_temperature);
            let target_prob = target_scaled_logits
                .get(candidate.token as usize)
                .map_or(0.0, |&logit| {
                    (logit - Self::log_sum_exp(&target_scaled_logits)).exp()
                });
            let acceptance_ratio = target_prob / candidate.probability;
            let random_value = self.simple_random();
            let accepted = random_value < acceptance_ratio.min(1.0)
                && acceptance_ratio >= self.config.acceptance_threshold;

            let token = if accepted {
                accepted_drafts += 1;
                candidate.token
            } else {
                // Only the candidate rejected at *this* position shapes the replacement;
                // tokens accepted at earlier positions are irrelevant here.
                let corrected_logits = self.correct_distribution(
                    &target_scaled_logits,
                    std::slice::from_ref(candidate),
                );
                self.sample_from_logits(&corrected_logits)?
            };

            emitted.push(token);
            self.target_context.decode(&[token])?;
            if self.target_model.token_is_eog(token) {
                return Ok((emitted, accepted_drafts, false));
            }
            if !accepted {
                break;
            }
        }
        Ok((emitted, accepted_drafts, true))
    }

    /// Divides `logits` by `temperature` (returned unchanged for exactly `1.0`).
    ///
    /// Temperatures below `MIN_TEMPERATURE`, and NaN, are clamped to it.
    fn apply_temperature(&self, logits: &[f32], temperature: f32) -> Vec<f32> {
        if temperature == 1.0 {
            return logits.to_vec();
        }
        // `f32::max` ignores NaN, so a NaN temperature also lands on the floor.
        let temperature = temperature.max(Self::MIN_TEMPERATURE);
        logits.iter().map(|&logit| logit / temperature).collect()
    }

    /// Draws an index from the softmax of `logits`.
    ///
    /// Tokens with zero probability (e.g. masked to `-inf`) are never returned.
    ///
    /// # Errors
    ///
    /// Returns `MullamaError::GenerationError` for empty logits or when no finite,
    /// positive probability mass exists (e.g. all `-inf` or containing `+inf`/NaN).
    fn sample_from_logits(&self, logits: &[f32]) -> Result<Token, MullamaError> {
        if logits.is_empty() {
            return Err(MullamaError::GenerationError(
                "Cannot sample from empty logits".to_string(),
            ));
        }
        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let weights: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
        let total: f32 = weights.iter().sum();
        if !(total.is_finite() && total > 0.0) {
            return Err(MullamaError::GenerationError(
                "Cannot sample from a degenerate distribution".to_string(),
            ));
        }
        // Scale the draw instead of normalizing every weight.
        let threshold = self.simple_random() * total;
        let mut cumulative = 0.0;
        let mut last_positive = 0;
        for (index, &weight) in weights.iter().enumerate() {
            if weight > 0.0 {
                cumulative += weight;
                last_positive = index;
                if threshold < cumulative {
                    return Ok(index as Token);
                }
            }
        }
        // Float rounding can leave `threshold` just above the final cumulative sum.
        Ok(last_positive as Token)
    }

    /// Returns a pseudo-random `f32` in `[0, 1)`.
    ///
    /// Uses the SplitMix64 finalizer over the clock mixed with a process-wide counter.
    /// The counter keeps back-to-back calls distinct even when the clock has not advanced.
    /// Not suitable for cryptographic use.
    fn simple_random(&self) -> f32 {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        // Weyl-sequence increment used by SplitMix64.
        const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
        static COUNTER: AtomicU64 = AtomicU64::new(0);

        // Truncating the nanosecond count is fine: it only seeds the mix.
        let clock_nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_nanos() as u64);
        let step = COUNTER
            .fetch_add(GOLDEN_GAMMA, Ordering::Relaxed)
            .wrapping_add(GOLDEN_GAMMA);

        let mut z = clock_nanos ^ step;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^= z >> 31;

        // The top 24 bits convert to an exactly representable f32 in [0, 1).
        (z >> 40) as f32 / (1u64 << 24) as f32
    }

    /// Builds the replacement distribution after a rejection by masking out the rejected
    /// candidates (setting their logits to `-inf`).
    ///
    /// A candidate is only rejected when `p_target < p_draft`, and the residual
    /// distribution `max(0, p_target - p_draft)` is zero at that token. If masking would
    /// leave no finite logit, the target logits are returned unchanged.
    fn correct_distribution(
        &self,
        target_logits: &[f32],
        rejected_candidates: &[CandidateToken],
    ) -> Vec<f32> {
        let mut corrected_logits = target_logits.to_vec();
        for candidate in rejected_candidates {
            if let Some(logit) = corrected_logits.get_mut(candidate.token as usize) {
                *logit = f32::NEG_INFINITY;
            }
        }
        if corrected_logits.iter().any(|logit| logit.is_finite()) {
            corrected_logits
        } else {
            target_logits.to_vec()
        }
    }

    /// Numerically stable `ln(sum(exp(logits)))`.
    fn log_sum_exp(logits: &[f32]) -> f32 {
        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        if !max_logit.is_finite() {
            return max_logit;
        }
        max_logit + logits.iter().map(|&x| (x - max_logit).exp()).sum::<f32>().ln()
    }

    /// Nanoseconds elapsed since `start`, saturating at `u64::MAX`.
    fn elapsed_ns(start: std::time::Instant) -> u64 {
        u64::try_from(start.elapsed().as_nanos()).unwrap_or(u64::MAX)
    }

    /// Clears both KV caches and decodes the prompt into each context.
    ///
    /// # Errors
    ///
    /// Returns `MullamaError::InvalidInput` for an empty prompt, since neither context
    /// would then have logits to sample from, and propagates decode errors.
    fn initialize_contexts(&mut self, prompt_tokens: &[Token]) -> Result<(), MullamaError> {
        if prompt_tokens.is_empty() {
            return Err(MullamaError::InvalidInput(
                "Prompt must contain at least one token".to_string(),
            ));
        }
        self.target_context.kv_cache_clear();
        self.draft_context.kv_cache_clear();
        self.target_context.decode(prompt_tokens)?;
        self.draft_context.decode(prompt_tokens)?;
        Ok(())
    }

    /// Brings the draft context in line with the tokens the target emitted.
    ///
    /// Context state is model-specific, so target state cannot be copied into the draft
    /// context. If every drafted candidate was emitted unchanged, the draft context is
    /// already in sync. Otherwise it is restored to the checkpoint taken before drafting,
    /// and the emitted tokens are replayed.
    ///
    /// NOTE(review): assumes `load_state` also restores the context's decode position —
    /// confirm against `Context`.
    fn update_contexts(
        &mut self,
        proposal: &DraftProposal,
        accepted_tokens: &[Token],
    ) -> Result<(), MullamaError> {
        let draft_in_sync = accepted_tokens.len() == proposal.candidates.len()
            && accepted_tokens
                .iter()
                .zip(&proposal.candidates)
                .all(|(token, candidate)| *token == candidate.token);
        if draft_in_sync {
            return Ok(());
        }
        self.draft_context
            .load_state(&proposal.draft_context_state)?;
        if !accepted_tokens.is_empty() {
            self.draft_context.decode(accepted_tokens)?;
        }
        Ok(())
    }

    /// Grows the lookahead (up to `MAX_DYNAMIC_LOOKAHEAD`) when more than 80% of tokens
    /// are accepted drafts, and shrinks it (down to 1) below 50%.
    ///
    /// Does nothing until `LOOKAHEAD_WARMUP_ROUNDS` rounds have been observed.
    fn adjust_lookahead(&mut self) {
        if self.stats.speculation_rounds < Self::LOOKAHEAD_WARMUP_ROUNDS {
            return;
        }
        let acceptance_rate = self.stats.acceptance_rate();
        if acceptance_rate > 0.8 {
            self.config.lookahead_tokens =
                (self.config.lookahead_tokens + 1).min(Self::MAX_DYNAMIC_LOOKAHEAD);
        } else if acceptance_rate < 0.5 {
            self.config.lookahead_tokens = self.config.lookahead_tokens.saturating_sub(1).max(1);
        }
    }

    /// Rejects model pairs whose vocabularies differ in size or type.
    ///
    /// Prints a warning, without failing, when the draft model is not smaller.
    fn validate_models(target: &Model, draft: &Model) -> Result<(), MullamaError> {
        if target.vocab_size() != draft.vocab_size() {
            return Err(MullamaError::InvalidInput(
                "Target and draft models must have the same vocabulary size".to_string(),
            ));
        }
        if target.vocab_type() != draft.vocab_type() {
            return Err(MullamaError::InvalidInput(
                "Target and draft models must use the same vocabulary type".to_string(),
            ));
        }
        if draft.n_params() >= target.n_params() {
            eprintln!("Warning: Draft model is not smaller than target model");
        }
        Ok(())
    }

    /// Cumulative statistics collected so far.
    pub fn stats(&self) -> &SpeculativeStats {
        &self.stats
    }

    /// Clears all collected statistics.
    pub fn reset_stats(&mut self) {
        self.stats = SpeculativeStats::default();
    }

    /// Current configuration; `lookahead_tokens` reflects any dynamic adjustment.
    pub fn config(&self) -> &SpeculativeConfig {
        &self.config
    }

    /// Replaces the configuration; it takes effect from the next speculative step.
    pub fn set_config(&mut self, config: SpeculativeConfig) {
        self.config = config;
    }
}
impl Default for SpeculativeConfig {
fn default() -> Self {
Self {
lookahead_tokens: 4,
acceptance_threshold: 0.6,
max_rejections: 10,
draft_temperature: 1.0,
target_temperature: 1.0,
dynamic_lookahead: true,
batch_size: 1,
}
}
}
/// Consuming builder methods; each returns the updated configuration.
impl SpeculativeConfig {
    /// Sets the maximum number of tokens drafted per speculative step.
    #[must_use]
    pub fn with_lookahead_tokens(mut self, tokens: usize) -> Self {
        self.lookahead_tokens = tokens;
        self
    }

    /// Sets the minimum target/draft probability ratio required to accept a draft.
    #[must_use]
    pub fn with_acceptance_threshold(mut self, threshold: f32) -> Self {
        self.acceptance_threshold = threshold;
        self
    }

    /// Sets `max_rejections`.
    #[must_use]
    pub fn with_max_rejections(mut self, max: usize) -> Self {
        self.max_rejections = max;
        self
    }

    /// Sets the temperature applied to the draft model's logits.
    #[must_use]
    pub fn with_draft_temperature(mut self, temp: f32) -> Self {
        self.draft_temperature = temp;
        self
    }

    /// Sets the temperature applied to the target model's logits.
    #[must_use]
    pub fn with_target_temperature(mut self, temp: f32) -> Self {
        self.target_temperature = temp;
        self
    }

    /// Enables or disables runtime adaptation of the lookahead.
    #[must_use]
    pub fn with_dynamic_lookahead(mut self, enabled: bool) -> Self {
        self.dynamic_lookahead = enabled;
        self
    }

    /// Sets `batch_size`.
    #[must_use]
    pub fn with_batch_size(mut self, size: usize) -> Self {
        self.batch_size = size;
        self
    }
}
impl SpeculativeStats {
    /// Ratio `accepted_tokens / total_tokens`, or `0.0` when nothing has been emitted yet.
    pub fn acceptance_rate(&self) -> f32 {
        match self.total_tokens {
            0 => 0.0,
            emitted => self.accepted_tokens as f32 / emitted as f32,
        }
    }

    /// Average number of tokens emitted per speculation round, or `1.0` before any round.
    pub fn speedup_factor(&self) -> f32 {
        match self.speculation_rounds {
            0 => 1.0,
            rounds => self.total_tokens as f32 / rounds as f32,
        }
    }

    /// Combined draft and target time, in seconds.
    pub fn total_time_seconds(&self) -> f64 {
        const NANOS_PER_SECOND: f64 = 1_000_000_000.0;
        let combined_ns = self.draft_time_ns + self.target_time_ns;
        combined_ns as f64 / NANOS_PER_SECOND
    }

    /// Emitted tokens per second of measured time, or `0.0` when no time was recorded.
    pub fn tokens_per_second(&self) -> f64 {
        let seconds = self.total_time_seconds();
        (seconds > 0.0)
            .then(|| self.total_tokens as f64 / seconds)
            .unwrap_or(0.0)
    }
}
pub mod utils {
use super::*;
pub fn find_optimal_lookahead(
target_model: Arc<Model>,
draft_model: Arc<Model>,
test_prompt: &[Token],
max_lookahead: usize,
) -> Result<usize, MullamaError> {
let mut best_lookahead = 1;
let mut best_speedup = 0.0;
for lookahead in 1..=max_lookahead {
let config = SpeculativeConfig::default()
.with_lookahead_tokens(lookahead)
.with_dynamic_lookahead(false);
let mut decoder = SpeculativeDecoder::new(
Arc::clone(&target_model),
Arc::clone(&draft_model),
config,
)?;
let _tokens = decoder.generate(test_prompt, 50)?;
let speedup = decoder.stats().speedup_factor();
if speedup > best_speedup {
best_speedup = speedup;
best_lookahead = lookahead;
}
}
Ok(best_lookahead)
}
pub fn speed_optimized_config() -> SpeculativeConfig {
SpeculativeConfig::default()
.with_lookahead_tokens(6)
.with_acceptance_threshold(0.5)
.with_draft_temperature(1.2)
.with_dynamic_lookahead(true)
}
pub fn quality_optimized_config() -> SpeculativeConfig {
SpeculativeConfig::default()
.with_lookahead_tokens(3)
.with_acceptance_threshold(0.8)
.with_draft_temperature(0.8)
.with_dynamic_lookahead(false)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_speculative_config() {
        let config = SpeculativeConfig::default()
            .with_lookahead_tokens(6)
            .with_acceptance_threshold(0.8);
        assert_eq!(config.lookahead_tokens, 6);
        assert_eq!(config.acceptance_threshold, 0.8);
    }

    #[test]
    fn test_speculative_stats() {
        let stats = SpeculativeStats {
            total_tokens: 100,
            accepted_tokens: 80,
            speculation_rounds: 25,
            ..Default::default()
        };
        assert_eq!(stats.acceptance_rate(), 0.8);
        assert_eq!(stats.speedup_factor(), 4.0);
    }

    /// The zero-division guards must return their documented fallbacks.
    #[test]
    fn test_speculative_stats_empty() {
        let stats = SpeculativeStats::default();
        assert_eq!(stats.acceptance_rate(), 0.0);
        assert_eq!(stats.speedup_factor(), 1.0);
        assert_eq!(stats.total_time_seconds(), 0.0);
        assert_eq!(stats.tokens_per_second(), 0.0);
    }

    #[test]
    fn test_speculative_stats_timing() {
        let stats = SpeculativeStats {
            total_tokens: 100,
            draft_time_ns: 500_000_000,
            target_time_ns: 1_500_000_000,
            ..Default::default()
        };
        assert_eq!(stats.total_time_seconds(), 2.0);
        assert_eq!(stats.tokens_per_second(), 50.0);
    }

    #[test]
    fn test_candidate_token() {
        let candidate = CandidateToken {
            token: 42,
            log_prob: -0.5,
            probability: 0.606,
        };
        assert_eq!(candidate.token, 42);
        assert!((candidate.probability - 0.606).abs() < 0.001);
    }

    #[test]
    fn test_config_presets() {
        let speed_config = utils::speed_optimized_config();
        assert_eq!(speed_config.lookahead_tokens, 6);
        assert_eq!(speed_config.acceptance_threshold, 0.5);
        assert_eq!(speed_config.draft_temperature, 1.2);
        assert!(speed_config.dynamic_lookahead);

        let quality_config = utils::quality_optimized_config();
        assert_eq!(quality_config.lookahead_tokens, 3);
        assert_eq!(quality_config.acceptance_threshold, 0.8);
        assert_eq!(quality_config.draft_temperature, 0.8);
        assert!(!quality_config.dynamic_lookahead);
    }
}