use candle_core::{Result, Tensor};
use rand::prelude::*;
/// Tunable settings consumed by `SpeculativeDecoder`.
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
/// Draft length used as `gamma` by `SpeculativeDecoder::expected_speedup`.
/// NOTE(review): `verify` takes the actual draft count from its input slice,
/// not from this field — presumably the caller drafts this many tokens per step.
pub num_speculative_tokens: usize,
/// Value the target logits are divided by before softmax. A value of exactly
/// `0.0` makes `sample_from_probs` return the argmax instead of sampling.
pub temperature: f32,
/// Presumably the nucleus (top-p) sampling threshold.
/// NOTE(review): no code visible in this file reads this field — confirm where it is applied.
pub top_p: f32,
/// When `true`, `verify` accepts a draft token with probability
/// `min(1, p_target / p_draft)`; when `false`, it accepts a draft token only
/// if it equals the argmax of the target distribution at that position.
pub use_rejection_sampling: bool,
}
impl Default for SpeculativeConfig {
    /// Returns the stock configuration: four speculative tokens, temperature
    /// `1.0`, top-p `0.9`, and rejection sampling switched on.
    fn default() -> Self {
        let num_speculative_tokens = 4;
        let temperature = 1.0;
        let top_p = 0.9;
        let use_rejection_sampling = true;

        Self {
            num_speculative_tokens,
            temperature,
            top_p,
            use_rejection_sampling,
        }
    }
}
/// Outcome of a single `SpeculativeDecoder::verify` call.
#[derive(Debug)]
pub struct SpeculativeResult {
/// Tokens emitted by the call: the accepted draft tokens in their original
/// order, followed by whatever `verify` appended from the target distribution.
pub accepted_tokens: Vec<u32>,
/// Number of draft tokens that `verify` reported as accepted.
pub num_accepted: usize,
/// Length of `accepted_tokens`.
pub total_generated: usize,
/// `num_accepted` divided by the number of draft tokens passed to `verify`,
/// or `0.0` when no draft tokens were supplied.
pub acceptance_rate: f32,
}
/// Verifies draft tokens against target-model logits and keeps running
/// acceptance statistics across calls.
pub struct SpeculativeDecoder {
// Settings that control temperature scaling and the acceptance rule.
config: SpeculativeConfig,
// Sum of the draft-token counts seen by `verify` since construction or the
// last `reset_stats`.
total_draft_tokens: usize,
// Sum of the accepted-token counts reported by `verify` over the same period.
total_accepted_tokens: usize,
}
impl SpeculativeDecoder {
pub fn new(config: SpeculativeConfig) -> Self {
Self {
config,
total_draft_tokens: 0,
total_accepted_tokens: 0,
}
}
pub fn verify(
&mut self,
draft_tokens: &[u32],
draft_probs: &[f32],
target_logits: &Tensor,
) -> Result<SpeculativeResult> {
let num_draft = draft_tokens.len();
let mut accepted_tokens = Vec::with_capacity(num_draft + 1);
let mut rng = rand::thread_rng();
let target_probs = self.compute_probs(target_logits)?;
let _vocab_size = target_probs.dim(1)?;
for i in 0..num_draft {
let draft_token = draft_tokens[i] as usize;
let draft_prob = draft_probs[i];
let target_prob_vec: Vec<f32> = target_probs.get(i)?.to_vec1()?;
let target_prob = target_prob_vec[draft_token];
let accept = if self.config.use_rejection_sampling {
let acceptance_prob = (target_prob / draft_prob.max(1e-10)).min(1.0);
rng.gen::<f32>() < acceptance_prob
} else {
let target_token = argmax(&target_prob_vec);
target_token == draft_token
};
if accept {
accepted_tokens.push(draft_tokens[i]);
} else {
if self.config.use_rejection_sampling {
let adjusted_token =
self.sample_adjusted(&target_prob_vec, draft_prob, draft_token, &mut rng);
accepted_tokens.push(adjusted_token);
} else {
accepted_tokens.push(argmax(&target_prob_vec) as u32);
}
break; }
}
if accepted_tokens.len() == num_draft {
let final_probs: Vec<f32> = target_probs.get(num_draft)?.to_vec1()?;
let next_token = self.sample_from_probs(&final_probs, &mut rng);
accepted_tokens.push(next_token);
}
let num_accepted = accepted_tokens.len().saturating_sub(1).min(num_draft);
self.total_draft_tokens += num_draft;
self.total_accepted_tokens += num_accepted;
let acceptance_rate = if num_draft > 0 {
num_accepted as f32 / num_draft as f32
} else {
0.0
};
let total_generated = accepted_tokens.len();
Ok(SpeculativeResult {
accepted_tokens,
num_accepted,
total_generated,
acceptance_rate,
})
}
fn compute_probs(&self, logits: &Tensor) -> Result<Tensor> {
let logits = if self.config.temperature != 1.0 {
(logits / self.config.temperature as f64)?
} else {
logits.clone()
};
let probs = candle_nn::ops::softmax_last_dim(&logits)?;
Ok(probs)
}
fn sample_adjusted(
&self,
target_probs: &[f32],
draft_prob: f32,
draft_token: usize,
rng: &mut impl Rng,
) -> u32 {
let mut adjusted: Vec<f32> = target_probs
.iter()
.enumerate()
.map(|(i, &p)| {
if i == draft_token {
(p - draft_prob).max(0.0)
} else {
p
}
})
.collect();
let sum: f32 = adjusted.iter().sum();
if sum > 1e-10 {
for p in &mut adjusted {
*p /= sum;
}
self.sample_from_probs(&adjusted, rng)
} else {
argmax(target_probs) as u32
}
}
fn sample_from_probs(&self, probs: &[f32], rng: &mut impl Rng) -> u32 {
if self.config.temperature == 0.0 {
return argmax(probs) as u32;
}
let r: f32 = rng.gen();
let mut cumsum = 0.0;
for (i, &p) in probs.iter().enumerate() {
cumsum += p;
if r < cumsum {
return i as u32;
}
}
(probs.len() - 1) as u32
}
pub fn overall_acceptance_rate(&self) -> f32 {
if self.total_draft_tokens > 0 {
self.total_accepted_tokens as f32 / self.total_draft_tokens as f32
} else {
0.0
}
}
pub fn reset_stats(&mut self) {
self.total_draft_tokens = 0;
self.total_accepted_tokens = 0;
}
pub fn expected_speedup(&self) -> f32 {
let gamma = self.config.num_speculative_tokens as f32;
let alpha = self.overall_acceptance_rate();
let beta = 10.0;
(1.0 + gamma * alpha) / (1.0 + gamma / beta)
}
}
/// Returns the index of the largest value in `slice`.
///
/// NaN entries are skipped rather than causing a panic. When several entries
/// share the maximum, the last of them wins. Returns `0` when the slice is
/// empty or contains only NaN.
fn argmax(slice: &[f32]) -> usize {
    let mut best: Option<(usize, f32)> = None;
    for (index, &value) in slice.iter().enumerate() {
        if value.is_nan() {
            continue;
        }
        match best {
            // Strictly smaller than the current best: keep the current best.
            Some((_, current)) if value < current => {}
            // First candidate, a new maximum, or a tie (later index wins).
            _ => best = Some((index, value)),
        }
    }
    best.map_or(0, |(index, _)| index)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Builds a decoder that verifies by exact argmax match (rejection
    /// sampling disabled), leaving every other setting at its default.
    fn greedy_decoder() -> SpeculativeDecoder {
        let config = SpeculativeConfig {
            use_rejection_sampling: false,
            ..Default::default()
        };
        SpeculativeDecoder::new(config)
    }

    /// Builds a `(4, 10)` logits tensor on the CPU that is zero everywhere
    /// except for `10.0` at each of the given flat (row-major) offsets.
    fn peaked_logits(hot_offsets: &[usize]) -> Tensor {
        let mut data = vec![0.0f32; 4 * 10];
        for &offset in hot_offsets {
            data[offset] = 10.0;
        }
        Tensor::from_vec(data, (4, 10), &Device::Cpu).unwrap()
    }

    #[test]
    fn test_speculative_config_default() {
        let SpeculativeConfig {
            num_speculative_tokens,
            temperature,
            ..
        } = SpeculativeConfig::default();
        assert_eq!(num_speculative_tokens, 4);
        assert_eq!(temperature, 1.0);
    }

    #[test]
    fn test_argmax() {
        let cases: [(&[f32], usize); 2] = [(&[0.1, 0.5, 0.3, 0.1], 1), (&[0.9, 0.05, 0.05], 0)];
        for (values, expected) in cases.iter() {
            assert_eq!(argmax(values), *expected);
        }
    }

    #[test]
    fn test_speculative_decoder_creation() {
        let fresh = SpeculativeDecoder::new(SpeculativeConfig::default());
        assert_eq!(fresh.overall_acceptance_rate(), 0.0);
    }

    #[test]
    fn test_verify_all_accepted() {
        // Rows 0..=2 peak at tokens 1, 2, 3 (matching the drafts); row 3 peaks at token 4.
        let target_logits = peaked_logits(&[1, 12, 23, 34]);
        let mut decoder = greedy_decoder();

        let outcome = decoder
            .verify(&[1, 2, 3], &[0.9, 0.8, 0.7], &target_logits)
            .unwrap();

        assert_eq!(outcome.num_accepted, 3);
        assert_eq!(outcome.accepted_tokens.len(), 4);
        assert_eq!(outcome.accepted_tokens[..3].to_vec(), vec![1u32, 2, 3]);
    }

    #[test]
    fn test_verify_early_rejection() {
        // Row 1 peaks at token 5 instead of the drafted token 2.
        let target_logits = peaked_logits(&[1, 15, 23, 34]);
        let mut decoder = greedy_decoder();

        let outcome = decoder
            .verify(&[1, 2, 3], &[0.9, 0.8, 0.7], &target_logits)
            .unwrap();

        assert_eq!(outcome.num_accepted, 1);
        assert_eq!(outcome.accepted_tokens.len(), 2);
        assert_eq!(outcome.accepted_tokens[0], 1);
        assert_eq!(outcome.accepted_tokens[1], 5);
    }

    #[test]
    fn test_expected_speedup() {
        let config = SpeculativeConfig {
            num_speculative_tokens: 4,
            ..Default::default()
        };
        let mut decoder = SpeculativeDecoder::new(config);
        // Simulate an 80% acceptance history without running `verify`.
        decoder.total_draft_tokens = 100;
        decoder.total_accepted_tokens = 80;

        let speedup = decoder.expected_speedup();
        assert!(2.5 < speedup && speedup < 3.5, "speedup = {}", speedup);
    }
}