use std::time::Instant;
use super::draft::{DraftModel, Xorshift64};
use super::types::{DecodingStats, SpeculativeConfig, TokenDistribution, VerificationResult};
use super::verifier::{SpeculativeVerifier, TargetModel};
/// Speculative decoder pairing a cheap draft model `D` with an expensive
/// target model `T`: the draft proposes several tokens per step, which are
/// then verified (accepted or rejected) against the target's distribution.
pub struct SpeculativeDecoder<D: DraftModel, T: TargetModel> {
    /// Cheap proposal model used to draft candidate tokens.
    draft: D,
    /// Expensive model whose distribution defines correctness.
    target: T,
    /// Decoding parameters (draft length, max tokens, temperature, top-k, ...).
    config: SpeculativeConfig,
    /// Accept/reject logic; seeded for reproducible runs.
    verifier: SpeculativeVerifier,
    /// Decoder-local RNG used for bonus-token sampling (seeded from `seed + 1`
    /// so it does not share a stream with the verifier).
    rng: Xorshift64,
}
impl<D: DraftModel, T: TargetModel> SpeculativeDecoder<D, T> {
    /// Creates a decoder from a draft model, a target model, and a config.
    ///
    /// The verifier is seeded with `seed`; the decoder's own RNG uses
    /// `seed + 1` (wrapping) so the two random streams differ.
    pub fn new(draft: D, target: T, config: SpeculativeConfig, seed: u64) -> Self {
        Self {
            draft,
            target,
            config,
            verifier: SpeculativeVerifier::new(seed),
            rng: Xorshift64::new(seed.wrapping_add(1)),
        }
    }

    /// Runs speculative decoding starting from `prompt`.
    ///
    /// Each step drafts up to `current_draft_length` tokens, verifies them
    /// against the target model, appends the accepted prefix (plus one bonus
    /// token sampled from the target when the whole draft is accepted), and
    /// stops once `config.max_tokens` tokens have been generated beyond the
    /// prompt or the draft model yields nothing.
    ///
    /// Returns the full sequence (prompt + generated tokens) together with
    /// per-run `DecodingStats`.
    pub fn decode(&mut self, prompt: &[usize]) -> (Vec<usize>, DecodingStats) {
        let start = Instant::now();
        // Loop-invariant output cap: prompt plus the generation budget.
        let max_len = prompt.len() + self.config.max_tokens;
        let mut output: Vec<usize> = prompt.to_vec();
        let mut stats = DecodingStats::new();
        let mut current_draft_length = self.config.draft_length;
        let mut step_count: usize = 0;
        // Rolling accept/draft counters drive adaptive draft-length tuning.
        let mut rolling_accepted: usize = 0;
        let mut rolling_drafted: usize = 0;
        while output.len() < max_len {
            let draft_result = self.draft.generate_draft(&output, current_draft_length);
            let draft_len = draft_result.len();
            if draft_len == 0 {
                // Draft model cannot propose anything; stop decoding.
                break;
            }
            let draft_tokens: Vec<usize> = draft_result.iter().map(|(t, _)| *t).collect();
            let draft_probs: Vec<f64> = draft_result.iter().map(|(_, p)| *p).collect();
            stats.draft_tokens += draft_len;
            rolling_drafted += draft_len;
            let (target_probs_per_pos, draft_probs_per_pos) =
                self.compute_distributions(&output, &draft_tokens);
            let vr: VerificationResult = self.verifier.verify_draft(
                &draft_tokens,
                &draft_probs,
                &target_probs_per_pos,
                &draft_probs_per_pos,
            );
            let accepted_count = vr.num_accepted();
            // On rejection the final "accepted" token is the verifier's
            // resampled correction, not a draft token, so it is excluded
            // from the draft-acceptance statistics.
            let accepted_this_step = if vr.all_accepted() {
                accepted_count
            } else {
                accepted_count.saturating_sub(1)
            };
            stats.accepted_tokens += accepted_this_step;
            // BUG FIX: this previously added the *cumulative*
            // `stats.accepted_tokens`, inflating the rolling acceptance rate
            // (eventually above 1.0) and pinning the adaptive draft length at
            // its maximum. Only this step's count belongs in the rolling sum.
            rolling_accepted += accepted_this_step;
            for &token in &vr.accepted_tokens {
                if output.len() >= max_len {
                    break;
                }
                output.push(token);
            }
            // A fully accepted draft earns one extra token sampled directly
            // from the target model (the standard speculative-decoding bonus).
            if vr.all_accepted() && output.len() < max_len {
                if let Some(bonus) = self.sample_from_target(&output) {
                    output.push(bonus);
                }
            }
            step_count += 1;
            if self.config.adaptive_draft && rolling_drafted > 0 {
                let rate = rolling_accepted as f64 / rolling_drafted as f64;
                current_draft_length =
                    adapt_draft_length(current_draft_length, rate, self.config.draft_length);
            }
        }
        // Defensive: the push loops above already respect `max_len`.
        if output.len() > max_len {
            output.truncate(max_len);
        }
        let elapsed = start.elapsed();
        let generated = output.len().saturating_sub(prompt.len());
        stats.total_tokens = generated;
        stats.wall_time_ms = elapsed.as_secs_f64() * 1000.0;
        stats.tokens_per_step = if step_count > 0 {
            generated as f64 / step_count as f64
        } else {
            0.0
        };
        (output, stats)
    }

    /// Computes, for every draft position, the target and draft probability
    /// distributions over the vocabulary, conditioning each position on the
    /// context extended with the preceding draft tokens.
    ///
    /// The draft distribution is approximated as "the sampled token keeps its
    /// reported probability; the remaining mass is spread uniformly over the
    /// other tokens", since the draft model only reports one (token, prob)
    /// pair per position.
    fn compute_distributions(
        &mut self,
        context: &[usize],
        draft_tokens: &[usize],
    ) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
        let vocab_size = self.target.vocab_size();
        let mut target_dists = Vec::with_capacity(draft_tokens.len());
        let mut draft_dists = Vec::with_capacity(draft_tokens.len());
        let mut extended_ctx = context.to_vec();
        for &token in draft_tokens {
            let log_probs = self.target.log_probs(&extended_ctx);
            let target_dist = log_probs_to_probs(&log_probs, self.config.temperature);
            // Optional top-k filtering of the target distribution; fall back
            // to uniform if the probabilities cannot form a distribution.
            let target_dist = if self.config.top_k > 0 && self.config.top_k < vocab_size {
                if let Some(dist) = TokenDistribution::from_probs(target_dist) {
                    if let Some(filtered) = dist.with_top_k(self.config.top_k) {
                        filtered.probs().to_vec()
                    } else {
                        dist.probs().to_vec()
                    }
                } else {
                    vec![1.0 / vocab_size as f64; vocab_size]
                }
            } else {
                target_dist
            };
            let draft_for_pos = self.draft.generate_draft(&extended_ctx, 1);
            let mut draft_dist = vec![0.0; vocab_size];
            // Single pattern match replaces the original's duplicated
            // `draft_for_pos.first()` calls; (usize, f64) is Copy.
            if let Some(&(tok, p)) = draft_for_pos.first() {
                let remaining = 1.0 - p;
                let uniform_part = if vocab_size > 1 {
                    remaining / (vocab_size - 1) as f64
                } else {
                    0.0
                };
                draft_dist.fill(uniform_part);
                if tok < vocab_size {
                    draft_dist[tok] = p;
                }
            } else {
                // Draft produced nothing for this position: assume uniform.
                draft_dist.fill(1.0 / vocab_size as f64);
            }
            target_dists.push(target_dist);
            draft_dists.push(draft_dist);
            extended_ctx.push(token);
        }
        (target_dists, draft_dists)
    }

    /// Samples one token from the target model's (temperature-scaled)
    /// distribution over `context`. Returns `None` when the probabilities
    /// cannot form a valid `TokenDistribution`.
    fn sample_from_target(&mut self, context: &[usize]) -> Option<usize> {
        let log_probs = self.target.log_probs(context);
        let probs = log_probs_to_probs(&log_probs, self.config.temperature);
        let dist = TokenDistribution::from_probs(probs)?;
        let u = self.rng.next_f64();
        Some(dist.sample_with_uniform(u))
    }

    /// Read-only access to the decoding configuration.
    pub fn config(&self) -> &SpeculativeConfig {
        &self.config
    }

    /// Mutable access to the decoding configuration.
    pub fn config_mut(&mut self) -> &mut SpeculativeConfig {
        &mut self.config
    }
}
/// Converts log-probabilities into a normalized probability distribution via
/// a temperature-scaled, numerically stable softmax (max is subtracted before
/// exponentiation).
///
/// Non-positive temperatures are treated as 1.0. If the exponentiated sum is
/// zero or NaN (e.g. every input is negative infinity), a uniform
/// distribution is returned instead.
fn log_probs_to_probs(log_probs: &[f64], temperature: f64) -> Vec<f64> {
    let n = log_probs.len();
    if n == 0 {
        return Vec::new();
    }
    let temp = if temperature > 0.0 { temperature } else { 1.0 };
    let max_scaled = log_probs
        .iter()
        .map(|&lp| lp / temp)
        .fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = log_probs
        .iter()
        .map(|&lp| (lp / temp - max_scaled).exp())
        .collect();
    let total: f64 = exps.iter().sum();
    if total.is_nan() || total <= 0.0 {
        return vec![1.0 / n as f64; n];
    }
    exps.into_iter().map(|e| e / total).collect()
}
/// Adjusts the speculative draft length based on the rolling acceptance rate.
///
/// High acceptance (> 0.8) grows the draft by one token, capped at twice the
/// initial length; low acceptance (< 0.4) shrinks it by one, floored at 1;
/// anything in between leaves the length unchanged.
fn adapt_draft_length(current: usize, acceptance_rate: f64, initial: usize) -> usize {
    let ceiling = initial * 2;
    match acceptance_rate {
        r if r > 0.8 => (current + 1).min(ceiling),
        r if r < 0.4 => current.saturating_sub(1).max(1),
        _ => current,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::speculative::draft::UniformDraftModel;

    /// Target model that always returns the same distribution, given as
    /// probabilities and converted to log-probs on each call.
    struct FixedTarget {
        probs: Vec<f64>,
    }

    impl FixedTarget {
        fn new(probs: Vec<f64>) -> Self {
            Self { probs }
        }
    }

    impl TargetModel for FixedTarget {
        fn log_probs(&self, _context: &[usize]) -> Vec<f64> {
            self.probs
                .iter()
                .map(|&p| if p > 0.0 { p.ln() } else { f64::NEG_INFINITY })
                .collect()
        }

        fn vocab_size(&self) -> usize {
            self.probs.len()
        }
    }

    #[test]
    fn test_decoder_produces_output() {
        let vocab = 10;
        let draft = UniformDraftModel::new(vocab, 42).expect("test: uniform draft model");
        let target = FixedTarget::new(vec![0.1; vocab]);
        let config = SpeculativeConfig {
            draft_length: 3,
            max_tokens: 20,
            ..Default::default()
        };
        let mut decoder = SpeculativeDecoder::new(draft, target, config, 42);
        let prompt = vec![0, 1, 2];
        let (output, stats) = decoder.decode(&prompt);
        assert_eq!(&output[..3], &[0, 1, 2]);
        assert!(output.len() > 3, "should generate beyond prompt");
        assert!(stats.total_tokens > 0);
        assert!(stats.wall_time_ms >= 0.0);
    }

    #[test]
    fn test_decoder_respects_max_tokens() {
        let vocab = 5;
        let draft = UniformDraftModel::new(vocab, 42).expect("test: uniform draft model");
        let target = FixedTarget::new(vec![0.2; vocab]);
        let max_tokens = 10;
        let config = SpeculativeConfig {
            draft_length: 4,
            max_tokens,
            ..Default::default()
        };
        let mut decoder = SpeculativeDecoder::new(draft, target, config, 42);
        let prompt = vec![0];
        let (output, stats) = decoder.decode(&prompt);
        assert!(
            output.len() <= prompt.len() + max_tokens,
            "output {} exceeds prompt {} + max_tokens {}",
            output.len(),
            prompt.len(),
            max_tokens
        );
        assert_eq!(stats.total_tokens, output.len() - prompt.len());
    }

    #[test]
    fn test_decoder_stats_tracking() {
        let vocab = 5;
        let draft = UniformDraftModel::new(vocab, 123).expect("test: uniform draft model");
        let target = FixedTarget::new(vec![0.2; vocab]);
        let config = SpeculativeConfig {
            draft_length: 2,
            max_tokens: 8,
            ..Default::default()
        };
        let mut decoder = SpeculativeDecoder::new(draft, target, config, 123);
        let (_, stats) = decoder.decode(&[0]);
        assert!(stats.draft_tokens > 0);
        assert!(stats.total_tokens > 0);
        assert!(stats.tokens_per_step > 0.0);
        assert!(stats.wall_time_ms >= 0.0);
    }

    #[test]
    fn test_adaptive_draft_length_increases() {
        let result = adapt_draft_length(4, 0.9, 4);
        assert_eq!(result, 5);
    }

    #[test]
    fn test_adaptive_draft_length_decreases() {
        let result = adapt_draft_length(4, 0.3, 4);
        assert_eq!(result, 3);
    }

    #[test]
    fn test_adaptive_draft_length_stays() {
        let result = adapt_draft_length(4, 0.6, 4);
        assert_eq!(result, 4);
    }

    #[test]
    fn test_adaptive_draft_length_floor() {
        let result = adapt_draft_length(1, 0.1, 4);
        assert_eq!(result, 1);
    }

    #[test]
    fn test_adaptive_draft_length_ceiling() {
        // Already at 2 * initial; a high acceptance rate must not grow it.
        let result = adapt_draft_length(8, 0.95, 4);
        assert_eq!(result, 8);
    }

    #[test]
    fn test_decoder_with_adaptive_enabled() {
        let vocab = 5;
        let draft = UniformDraftModel::new(vocab, 42).expect("test: uniform draft model");
        let target = FixedTarget::new(vec![0.2; vocab]);
        let config = SpeculativeConfig {
            draft_length: 3,
            max_tokens: 15,
            adaptive_draft: true,
            ..Default::default()
        };
        let mut decoder = SpeculativeDecoder::new(draft, target, config, 42);
        let (output, _stats) = decoder.decode(&[0]);
        assert!(output.len() > 1);
        assert_eq!(output[0], 0);
    }

    #[test]
    fn test_log_probs_to_probs() {
        // All-NEG_INFINITY inputs fall back to a uniform distribution.
        let probs = log_probs_to_probs(&[f64::NEG_INFINITY; 3], 1.0);
        assert_eq!(probs.len(), 3);
        for &p in &probs {
            assert!((p - 1.0 / 3.0).abs() < 1e-10);
        }
        // Equal log-probs normalize to a uniform distribution.
        let probs = log_probs_to_probs(&[0.0, 0.0, 0.0], 1.0);
        for &p in &probs {
            assert!((p - 1.0 / 3.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_log_probs_to_probs_with_temperature() {
        let log_probs = vec![-1.0, 0.0, -2.0];
        let probs = log_probs_to_probs(&log_probs, 0.1);
        assert!(probs[1] > 0.9, "low temp should sharpen: {:.4}", probs[1]);
    }

    #[test]
    fn test_decoder_empty_prompt() {
        let vocab = 5;
        let draft = UniformDraftModel::new(vocab, 42).expect("test: uniform draft model");
        let target = FixedTarget::new(vec![0.2; vocab]);
        let config = SpeculativeConfig {
            draft_length: 2,
            max_tokens: 5,
            ..Default::default()
        };
        let mut decoder = SpeculativeDecoder::new(draft, target, config, 42);
        let (output, stats) = decoder.decode(&[]);
        assert!(output.len() <= 5);
        assert_eq!(stats.total_tokens, output.len());
    }

    #[test]
    fn test_decoding_stats_default() {
        let stats = DecodingStats::default();
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.draft_tokens, 0);
        assert_eq!(stats.accepted_tokens, 0);
        assert!((stats.acceptance_rate() - 0.0).abs() < 1e-10);
        assert!((stats.throughput() - 0.0).abs() < 1e-10);
    }
}