use crate::model::{InferenceContext, Model};
use crate::sampling::Sampler;
use crate::tensor::Tensor;
/// Tunable parameters for speculative decoding.
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
    /// Upper bound on how many tokens the draft model proposes in one
    /// draft/verify round of `SpeculativeDecoder::generate`.
    pub num_speculative: usize,
    /// Temperature associated with the draft model.
    ///
    /// NOTE(review): not read by any code in this file — presumably callers
    /// use it when constructing the draft sampler; confirm.
    pub draft_temperature: f32,
    /// Temperature associated with the target model.
    ///
    /// NOTE(review): not read by any code in this file — presumably callers
    /// use it when constructing the target sampler; confirm.
    pub target_temperature: f32,
}

impl Default for SpeculativeConfig {
    /// Returns a configuration that drafts four tokens per round and uses a
    /// temperature of 0.8 for both models.
    fn default() -> Self {
        let shared_temperature = 0.8;
        SpeculativeConfig {
            num_speculative: 4,
            draft_temperature: shared_temperature,
            target_temperature: shared_temperature,
        }
    }
}
/// Running counters describing how speculative decoding has performed.
///
/// All counters start at zero (via `Default`) and are maintained by
/// `SpeculativeDecoder::generate`.
#[derive(Debug, Clone, Default)]
pub struct SpeculativeStats {
    /// Total number of tokens appended to the output; the denominator of
    /// [`SpeculativeStats::acceptance_rate`].
    pub total_tokens: usize,
    /// Tokens counted as accepted. `generate` also counts the bonus token
    /// sampled after a fully accepted round here.
    pub accepted_tokens: usize,
    /// Number of draft tokens that failed verification.
    pub rejected_tokens: usize,
    /// Number of draft/verify rounds recorded.
    pub batches: usize,
}

impl SpeculativeStats {
    /// Fraction of all recorded tokens that were accepted.
    ///
    /// Returns `0.0` when no tokens have been recorded, avoiding a division
    /// by zero.
    pub fn acceptance_rate(&self) -> f32 {
        match self.total_tokens {
            0 => 0.0,
            total => self.accepted_tokens as f32 / total as f32,
        }
    }

    /// Mean number of accepted tokens per recorded round.
    ///
    /// Returns `0.0` when no rounds have been recorded, avoiding a division
    /// by zero.
    pub fn avg_accepted_per_batch(&self) -> f32 {
        match self.batches {
            0 => 0.0,
            rounds => self.accepted_tokens as f32 / rounds as f32,
        }
    }
}
/// Speculative decoder: a draft model proposes tokens and a target model
/// verifies them one by one.
///
/// Holds the decoding configuration and the statistics updated by
/// [`SpeculativeDecoder::generate`].
pub struct SpeculativeDecoder {
    /// Settings controlling each draft/verify round.
    config: SpeculativeConfig,
    /// Counters accumulated over every `generate` call since construction or
    /// the last [`SpeculativeDecoder::reset_stats`].
    stats: SpeculativeStats,
}

impl SpeculativeDecoder {
    /// Creates a decoder with the given configuration and zeroed statistics.
    pub fn new(config: SpeculativeConfig) -> Self {
        Self {
            config,
            stats: SpeculativeStats::default(),
        }
    }

    /// Returns the statistics accumulated so far.
    pub fn stats(&self) -> &SpeculativeStats {
        &self.stats
    }

    /// Resets every statistics counter to zero.
    pub fn reset_stats(&mut self) {
        self.stats = SpeculativeStats::default();
    }

    /// Generates up to `max_tokens` tokens after `input_tokens`.
    ///
    /// Each round has three steps:
    /// 1. The draft model proposes up to `config.num_speculative` tokens
    ///    (never more than the remaining `max_tokens` budget), stopping early
    ///    at `eos_token` or when the draft KV cache would be full.
    /// 2. The target model scores each proposal in order. A proposal is kept
    ///    when `r * p_draft <= p_target` for a fresh uniform `r`; the first
    ///    failing proposal is replaced by a token drawn from the residual
    ///    distribution and the round's verification stops.
    /// 3. If every proposal was kept and budget remains, one extra "bonus"
    ///    token is sampled from the target model.
    ///
    /// Generation stops as soon as any emitted token equals `eos_token`, when
    /// `max_tokens` tokens have been produced, or when no token can be
    /// drafted.
    ///
    /// # Arguments
    /// * `draft_model` / `draft_ctx` / `draft_sampler` - model, inference
    ///   context and sampler used for proposing tokens.
    /// * `target_model` / `target_ctx` / `target_sampler` - model, inference
    ///   context and sampler used for verification, replacement and bonus
    ///   tokens.
    /// * `input_tokens` - prompt; copied to the front of the returned vector.
    /// * `max_tokens` - maximum number of tokens to append.
    /// * `eos_token` - token id that terminates generation.
    ///
    /// # Returns
    /// The prompt followed by all generated tokens (including a terminating
    /// `eos_token` when one was produced).
    ///
    /// # Errors
    /// Propagates any error from either model's `forward` call or from
    /// converting logits to probabilities.
    ///
    /// NOTE(review): only the most recent token is ever fed to either model,
    /// so both contexts presumably must already contain the rest of the
    /// prompt — confirm with callers. With an empty prompt, token id 0 is fed.
    ///
    /// NOTE(review): `draft_ctx` is never rewound after a rejection and never
    /// sees the last draft token of a fully accepted round. If
    /// `InferenceContext` tracks sequence position statefully, the draft cache
    /// drifts from `output_tokens` — confirm and add rollback/catch-up.
    ///
    /// NOTE(review): `p_draft` is the plain softmax of the draft logits while
    /// the proposal itself comes from `draft_sampler`; the acceptance rule is
    /// exact only if the sampler draws from that same distribution — confirm.
    pub fn generate(
        &mut self,
        draft_model: &dyn Model,
        target_model: &dyn Model,
        draft_ctx: &mut InferenceContext,
        target_ctx: &mut InferenceContext,
        draft_sampler: &mut Sampler,
        target_sampler: &mut Sampler,
        input_tokens: &[u32],
        max_tokens: usize,
        eos_token: u32,
    ) -> Result<Vec<u32>, Box<dyn std::error::Error>> {
        let mut output_tokens = input_tokens.to_vec();
        let mut generated: usize = 0;
        // Set once an end-of-sequence token has been emitted by any path.
        let mut finished = false;

        while !finished && generated < max_tokens {
            // ---- Draft phase -------------------------------------------
            // Never draft more tokens than the caller's budget can absorb.
            let draft_budget = self.config.num_speculative.min(max_tokens - generated);
            let mut draft_tokens: Vec<u32> = Vec::with_capacity(draft_budget);
            let mut draft_probs = Vec::with_capacity(draft_budget);
            for _ in 0..draft_budget {
                if output_tokens.len() + draft_tokens.len() >= draft_ctx.kv_cache.max_seq_len {
                    break;
                }
                // Continue from the newest draft token, else the newest
                // output token, else token id 0 for an empty prompt.
                let last_token = draft_tokens
                    .last()
                    .or(output_tokens.last())
                    .copied()
                    .unwrap_or(0);
                let logits = draft_model.forward(&[last_token], draft_ctx)?;
                let probs = softmax_logits(&logits)?;
                let token = draft_sampler.sample(&logits, &output_tokens);
                draft_tokens.push(token);
                draft_probs.push(probs);
                if token == eos_token {
                    break;
                }
            }
            if draft_tokens.is_empty() {
                break;
            }

            // ---- Verify phase ------------------------------------------
            let mut accepted: usize = 0;
            for (&draft_token, draft_dist) in draft_tokens.iter().zip(draft_probs.iter()) {
                // Every previously accepted draft token has been pushed, so
                // the newest output token is always the right model input.
                let last_token = *output_tokens.last().unwrap_or(&0);
                let target_logits = target_model.forward(&[last_token], target_ctx)?;
                let target_probs = softmax_logits(&target_logits)?;
                let draft_prob = get_token_prob(draft_dist, draft_token);
                let target_prob = get_token_prob(&target_probs, draft_token);

                // Accept with probability min(1, p_target / p_draft).
                let r: f32 = rand::random();
                let was_accepted = r * draft_prob <= target_prob;
                let token = if was_accepted {
                    accepted += 1;
                    self.stats.accepted_tokens += 1;
                    draft_token
                } else {
                    self.stats.rejected_tokens += 1;
                    sample_adjusted_distribution(
                        &target_probs,
                        draft_dist,
                        target_sampler,
                        &output_tokens,
                    )
                };

                output_tokens.push(token);
                generated += 1;
                self.stats.total_tokens += 1;

                // EOS must end the whole generation, not just this round.
                if token == eos_token {
                    finished = true;
                }
                // Later proposals were conditioned on the rejected token, so
                // a rejection always ends the round.
                if !was_accepted || finished || generated >= max_tokens {
                    break;
                }
            }

            // ---- Bonus token -------------------------------------------
            // Only when every proposal survived, nothing ended the sequence,
            // and budget remains.
            if !finished && accepted == draft_tokens.len() && generated < max_tokens {
                let last_token = *output_tokens.last().unwrap_or(&0);
                let target_logits = target_model.forward(&[last_token], target_ctx)?;
                let bonus_token = target_sampler.sample(&target_logits, &output_tokens);
                output_tokens.push(bonus_token);
                generated += 1;
                self.stats.total_tokens += 1;
                self.stats.accepted_tokens += 1;
                finished = bonus_token == eos_token;
            }

            // Count the round even when it ended the sequence.
            self.stats.batches += 1;
        }

        Ok(output_tokens)
    }
}
/// Converts a tensor of logits into probabilities with a softmax taken over
/// every element of the tensor.
///
/// The result has the same shape as `logits`.
///
/// # Errors
/// Fails if the logits cannot be read as `f32` or if the output tensor
/// cannot be built.
fn softmax_logits(logits: &Tensor) -> Result<Tensor, Box<dyn std::error::Error>> {
    let values = logits.as_f32()?;

    // Shift by the largest logit so every exponent is <= 0 and `exp` cannot
    // overflow.
    let peak = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = values.iter().map(|&v| (v - peak).exp()).collect();

    let total: f32 = exps.iter().sum();
    let normalized: Vec<f32> = exps.iter().map(|&e| e / total).collect();

    match Tensor::from_f32(&normalized, logits.shape().to_vec()) {
        Ok(tensor) => Ok(tensor),
        Err(err) => Err(Box::new(err)),
    }
}
/// Looks up the probability stored for `token` in `probs`.
///
/// Returns `0.0` when the tensor cannot be read as `f32` or when `token`
/// is outside the tensor's range.
fn get_token_prob(probs: &Tensor, token: u32) -> f32 {
    match probs.as_f32() {
        Ok(values) => values.get(token as usize).copied().unwrap_or(0.0),
        Err(_) => 0.0,
    }
}
/// Samples a replacement token after a draft token was rejected.
///
/// Builds the residual distribution `max(0, target - draft)`, renormalises
/// it, converts it to log-space and hands it to `sampler` together with
/// `context`.
///
/// The residual always has one entry per target probability. Where the
/// draft distribution is shorter, the missing draft probabilities are
/// treated as `0.0` (the same convention `get_token_prob` uses for
/// out-of-range tokens), so the residual keeps the target's length and
/// matches `target_probs.shape()`.
///
/// Fallbacks:
/// * If either tensor cannot be read as `f32`, token id `0` is returned.
/// * If the residual has no positive mass, the target distribution itself
///   is used.
/// * If the log-space tensor cannot be built, an all-zero logits tensor of
///   the target's shape is sampled instead.
///
/// NOTE(review): the `1e-10` floor keeps `ln` finite but also leaves every
/// zero-residual token (including the rejected one) with a tiny non-zero
/// weight; whether `Sampler` tolerates `-inf` logits is not visible here —
/// confirm before tightening.
fn sample_adjusted_distribution(
    target_probs: &Tensor,
    draft_probs: &Tensor,
    sampler: &mut Sampler,
    context: &[u32],
) -> u32 {
    let target_data = match target_probs.as_f32() {
        Ok(d) => d,
        Err(_) => return 0,
    };
    let draft_data = match draft_probs.as_f32() {
        Ok(d) => d,
        Err(_) => return 0,
    };

    // Walk the target distribution so the residual keeps its full length
    // even when the draft distribution has fewer entries.
    let mut adjusted: Vec<f32> = target_data
        .iter()
        .enumerate()
        .map(|(idx, &t)| {
            let d = draft_data.get(idx).copied().unwrap_or(0.0);
            (t - d).max(0.0)
        })
        .collect();

    let sum: f32 = adjusted.iter().sum();
    if sum > 0.0 {
        for p in &mut adjusted {
            *p /= sum;
        }
    } else {
        // No positive residual mass: fall back to the target distribution.
        adjusted = target_data.to_vec();
    }

    // The sampler consumes logits, so move the probabilities to log-space.
    let logits: Vec<f32> = adjusted.iter().map(|&p| (p + 1e-10).ln()).collect();
    let logits_tensor =
        Tensor::from_f32(&logits, target_probs.shape().to_vec()).unwrap_or_else(|_| {
            Tensor::zeros(target_probs.shape().to_vec(), crate::tensor::DType::F32)
        });
    sampler.sample(&logits_tensor, context)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration drafts four tokens with a draft
    /// temperature of roughly 0.8.
    #[test]
    fn test_speculative_config_default() {
        let SpeculativeConfig {
            num_speculative,
            draft_temperature,
            ..
        } = SpeculativeConfig::default();
        assert_eq!(num_speculative, 4);
        assert!((draft_temperature - 0.8).abs() < 0.01);
    }

    /// Derived ratios follow directly from the raw counters.
    #[test]
    fn test_speculative_stats() {
        let stats = SpeculativeStats {
            total_tokens: 100,
            accepted_tokens: 75,
            rejected_tokens: 25,
            batches: 20,
        };
        let rate_error = (stats.acceptance_rate() - 0.75).abs();
        let per_batch_error = (stats.avg_accepted_per_batch() - 3.75).abs();
        assert!(rate_error < 0.01);
        assert!(per_batch_error < 0.01);
    }

    /// A freshly built decoder starts with empty statistics.
    #[test]
    fn test_speculative_decoder_creation() {
        let decoder = SpeculativeDecoder::new(SpeculativeConfig {
            num_speculative: 6,
            draft_temperature: 0.5,
            target_temperature: 0.7,
        });
        assert_eq!(decoder.stats().total_tokens, 0);
    }
}