use crate::engine::InferenceEngine;
use crate::sampling::SamplingParams;
/// Tunable parameters for speculative decoding.
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
    /// Maximum number of tokens the draft engine proposes per step.
    pub lookahead: usize,
    /// Margin subtracted from the acceptance probability when the target
    /// probability of a drafted token is below its draft probability.
    pub acceptance_threshold: f32,
}

impl Default for SpeculativeConfig {
    /// Returns a configuration with a lookahead of five tokens and a zero
    /// acceptance threshold.
    fn default() -> Self {
        let lookahead = 5;
        let acceptance_threshold = 0.0;
        SpeculativeConfig {
            lookahead,
            acceptance_threshold,
        }
    }
}
/// Outcome of a single draft-and-verify round, as returned by
/// `SpeculativeDecoder::step`.
#[derive(Debug, Clone)]
pub struct SpeculativeStep {
/// Tokens proposed by the draft engine during this step.
pub draft_tokens: Vec<u32>,
/// Leading run of `draft_tokens` that passed verification; verification
/// stops at the first rejected token.
pub accepted_tokens: Vec<u32>,
/// `accepted_tokens.len() / draft_tokens.len()`, or `0.0` when nothing was
/// drafted.
pub acceptance_rate: f32,
}
/// Small xorshift64 pseudo-random generator (shift triple 13 / 7 / 17).
///
/// Deterministic for a given seed; not suitable for cryptographic use.
struct Xorshift64 {
    /// Current generator state; never zero after construction.
    state: u64,
}

impl Xorshift64 {
    /// Creates a generator from `seed`.
    ///
    /// A zero seed is replaced by a fixed non-zero constant, because a zero
    /// state would map to zero on every subsequent step.
    fn new(seed: u64) -> Self {
        match seed {
            0 => Self {
                state: 0xdeadbeef_cafebabe,
            },
            nonzero => Self { state: nonzero },
        }
    }

    /// Advances the generator and returns the new 64-bit state.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Returns a sample in `[0, 1)` built from the top 24 bits of the next
    /// 64-bit output (24 bits fit an `f32` mantissa exactly).
    fn next_f32(&mut self) -> f32 {
        let top_bits = self.next_u64() >> 40;
        let scale = (1u64 << 24) as f32;
        top_bits as f32 / scale
    }
}
/// Speculative decoder: a draft engine proposes tokens which are then
/// accepted or rejected against target-model logits.
///
/// Running totals over all calls to [`SpeculativeDecoder::step`] are kept in
/// the public counter fields and can be cleared with
/// [`SpeculativeDecoder::reset_stats`].
pub struct SpeculativeDecoder<'a> {
    /// Engine used to produce draft tokens.
    pub draft_engine: InferenceEngine<'a>,
    /// Lookahead and acceptance settings.
    pub config: SpeculativeConfig,
    /// Number of completed `step` calls.
    pub total_steps: u64,
    /// Total number of tokens drafted across all steps.
    pub total_draft_tokens: u64,
    /// Total number of drafted tokens accepted across all steps.
    pub total_accepted_tokens: u64,
    // Never read: `verify` takes `&self` and therefore builds its own
    // generator seeded from `total_steps` instead of advancing this one.
    #[allow(dead_code)]
    rng: Xorshift64,
}

impl<'a> SpeculativeDecoder<'a> {
    /// Creates a decoder around `draft_engine` with all statistics at zero.
    pub fn new(draft_engine: InferenceEngine<'a>, config: SpeculativeConfig) -> Self {
        Self {
            draft_engine,
            config,
            total_steps: 0,
            total_draft_tokens: 0,
            total_accepted_tokens: 0,
            rng: Xorshift64::new(0xfeed1234_5678abcd),
        }
    }

    /// Drafts up to `config.lookahead` tokens, one at a time, feeding each
    /// drafted token back into the context for the next call.
    ///
    /// Drafting stops early if the draft engine returns an error or an empty
    /// result, so the returned vector may be shorter than the lookahead.
    /// `_params` is currently unused.
    pub fn draft(&mut self, context: &[u32], _params: &SamplingParams) -> Vec<u32> {
        self.draft_up_to(context, self.config.lookahead)
    }

    /// Drafts at most `max_draft` tokens autoregressively from `context`.
    fn draft_up_to(&mut self, context: &[u32], max_draft: usize) -> Vec<u32> {
        let mut draft_tokens = Vec::with_capacity(max_draft);
        if max_draft == 0 {
            return draft_tokens;
        }
        let mut current_context: Vec<u32> = context.to_vec();
        for _ in 0..max_draft {
            match self.draft_engine.generate(&current_context, 1) {
                Ok(generated) if !generated.is_empty() => {
                    let token = generated[0];
                    draft_tokens.push(token);
                    current_context.push(token);
                }
                // Best effort: an engine error or empty output simply ends
                // drafting with whatever has been produced so far.
                _ => break,
            }
        }
        draft_tokens
    }

    /// Checks `draft_tokens` against `target_logits` and returns the accepted
    /// leading run.
    ///
    /// `target_logits[i]` is the target distribution (as raw logits) used to
    /// judge `draft_tokens[i]`. Verification stops at the first rejected
    /// token, at the first empty logit row, or when either slice runs out.
    /// The draft probability is taken to be uniform (`1 / vocab_size`, where
    /// `vocab_size` is the length of the logit row). A token id outside the
    /// logit row has target probability `0.0`.
    ///
    /// The random samples come from a generator seeded from `total_steps`,
    /// so repeated calls without an intervening `step` give identical
    /// results. `_params` is currently unused.
    pub fn verify(
        &self,
        draft_tokens: &[u32],
        target_logits: &[Vec<f32>],
        _params: &SamplingParams,
    ) -> Vec<u32> {
        let mut accepted = Vec::with_capacity(draft_tokens.len());
        let mut local_rng = Xorshift64::new(
            self.total_steps
                .wrapping_mul(6364136223846793005)
                .wrapping_add(0xabcdef01),
        );
        let threshold = self.config.acceptance_threshold;

        // `zip` ends at the shorter of the two slices, which covers the case
        // of missing logit rows for trailing draft tokens.
        for (&token, logits) in draft_tokens.iter().zip(target_logits) {
            if logits.is_empty() {
                break;
            }
            let target_probs = softmax(logits);
            let target_prob = target_probs.get(token as usize).copied().unwrap_or(0.0);
            let vocab_size = logits.len() as f32;
            let draft_prob = (1.0 / vocab_size).max(1e-9);
            // Drawn for every position, even when the token is accepted
            // outright, so the sample sequence does not depend on outcomes.
            let rng_sample = local_rng.next_f32();
            if !Self::should_accept(draft_prob, target_prob, threshold, rng_sample) {
                break;
            }
            accepted.push(token);
        }
        accepted
    }

    /// Runs one draft-and-verify round and updates the running statistics.
    ///
    /// At most `min(config.lookahead, target_logits.len())` tokens are
    /// drafted: a drafted token without a matching logit row can never be
    /// verified, so drafting it would only waste draft-engine calls and
    /// deflate the acceptance statistics.
    pub fn step(
        &mut self,
        context: &[u32],
        target_logits: &[Vec<f32>],
        params: &SamplingParams,
    ) -> SpeculativeStep {
        let verifiable = self.config.lookahead.min(target_logits.len());
        let draft_tokens = self.draft_up_to(context, verifiable);
        let n_drafted = draft_tokens.len();
        let accepted_tokens = self.verify(&draft_tokens, target_logits, params);
        let n_accepted = accepted_tokens.len();

        self.total_steps += 1;
        self.total_draft_tokens += n_drafted as u64;
        self.total_accepted_tokens += n_accepted as u64;

        let acceptance_rate = if n_drafted > 0 {
            n_accepted as f32 / n_drafted as f32
        } else {
            0.0
        };
        SpeculativeStep {
            draft_tokens,
            accepted_tokens,
            acceptance_rate,
        }
    }

    /// Generates up to `max_tokens` tokens after `prompt_tokens`.
    ///
    /// Each iteration builds synthetic target logits (a single peaked token
    /// per position over a fixed 32000-entry vocabulary, derived from the
    /// last context token), runs [`SpeculativeDecoder::step`], and appends
    /// the accepted tokens. When nothing is accepted, one token is taken
    /// directly from the draft engine so the loop always makes progress;
    /// generation ends early if that fallback fails or returns nothing.
    pub fn generate_speculative(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
        params: &SamplingParams,
    ) -> Vec<u32> {
        /// Vocabulary size of the synthetic target distribution.
        const SYNTHETIC_VOCAB_SIZE: usize = 32_000;

        let mut output: Vec<u32> = Vec::with_capacity(max_tokens);
        let mut context: Vec<u32> = prompt_tokens.to_vec();

        while output.len() < max_tokens {
            let remaining = max_tokens - output.len();
            let effective_lookahead = self.config.lookahead.min(remaining);

            // Reduce first so the additions below cannot overflow `usize`;
            // the result modulo the vocabulary size is unchanged.
            let last_token =
                context.last().copied().unwrap_or(0) as usize % SYNTHETIC_VOCAB_SIZE;
            let target_logits: Vec<Vec<f32>> = (0..effective_lookahead)
                .map(|step_idx| {
                    let peak_token = (last_token + step_idx + 1) % SYNTHETIC_VOCAB_SIZE;
                    let mut logits = vec![-2.0f32; SYNTHETIC_VOCAB_SIZE];
                    logits[peak_token] = 10.0;
                    logits
                })
                .collect();

            let step_result = self.step(&context, &target_logits, params);
            if step_result.accepted_tokens.is_empty() {
                // Fallback: take a single token straight from the draft
                // engine so the loop cannot stall.
                match self.draft_engine.generate(&context, 1) {
                    Ok(t) if !t.is_empty() => {
                        let token = t[0];
                        output.push(token);
                        context.push(token);
                    }
                    _ => break,
                }
            } else {
                // `verify` accepts at most one token per logit row, so this
                // never exceeds `remaining`; the `min` is a cheap safeguard.
                let to_take = step_result.accepted_tokens.len().min(remaining);
                let taken = &step_result.accepted_tokens[..to_take];
                output.extend_from_slice(taken);
                context.extend_from_slice(taken);
            }
        }
        output
    }

    /// Fraction of drafted tokens that were accepted over all recorded
    /// steps, or `0.0` when nothing has been drafted yet.
    pub fn acceptance_rate(&self) -> f32 {
        if self.total_draft_tokens == 0 {
            return 0.0;
        }
        self.total_accepted_tokens as f32 / self.total_draft_tokens as f32
    }

    /// Average number of accepted tokens per step, floored at `1.0`.
    /// Returns `1.0` when no steps have been recorded.
    pub fn speedup_estimate(&self) -> f32 {
        if self.total_steps == 0 {
            return 1.0;
        }
        let avg_accepted = self.total_accepted_tokens as f32 / self.total_steps as f32;
        avg_accepted.max(1.0)
    }

    /// Clears the step, draft and acceptance counters.
    pub fn reset_stats(&mut self) {
        self.total_steps = 0;
        self.total_draft_tokens = 0;
        self.total_accepted_tokens = 0;
    }

    /// Acceptance rule for a single drafted token.
    ///
    /// A token whose target probability is at least its draft probability is
    /// always accepted (the threshold is not consulted). Otherwise it is
    /// accepted when `rng_sample < target_prob / draft_prob - threshold`,
    /// with the ratio clamped to be non-negative.
    fn should_accept(draft_prob: f32, target_prob: f32, threshold: f32, rng_sample: f32) -> bool {
        if target_prob >= draft_prob {
            return true;
        }
        let accept_prob = (target_prob / draft_prob).max(0.0);
        rng_sample < accept_prob - threshold
    }
}
/// Numerically stable softmax over `logits`.
///
/// Returns an empty vector for empty input. Non-finite input is handled
/// explicitly instead of producing NaN probabilities:
///
/// * NaN logits receive probability `0.0`.
/// * If at least one logit is `+inf`, the mass is split evenly among the
///   `+inf` entries and every other entry gets `0.0`.
/// * If no logit is usable (all `-inf` and/or NaN), or the normaliser is
///   degenerate, a uniform distribution is returned.
fn softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }
    let uniform = || vec![1.0 / logits.len() as f32; logits.len()];

    // `f32::max` ignores NaN operands, so this is the largest non-NaN logit
    // (or `-inf` when there is none).
    let max_val = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if max_val == f32::NEG_INFINITY {
        return uniform();
    }
    if max_val == f32::INFINITY {
        // `inf - inf` would be NaN below, so handle the limit directly.
        let hot = logits.iter().filter(|&&l| l == f32::INFINITY).count() as f32;
        return logits
            .iter()
            .map(|&l| if l == f32::INFINITY { 1.0 / hot } else { 0.0 })
            .collect();
    }

    // Subtracting the maximum keeps every exponent <= 0 and avoids overflow.
    let exps: Vec<f32> = logits
        .iter()
        .map(|&l| {
            let e = (l - max_val).exp();
            if e.is_nan() {
                0.0
            } else {
                e
            }
        })
        .collect();
    let sum: f32 = exps.iter().sum();
    if !sum.is_finite() || sum < 1e-30 {
        return uniform();
    }
    exps.iter().map(|&e| e / sum).collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::config::Qwen3Config;

    /// Builds a decoder around a tiny test engine with the given lookahead
    /// and a zero acceptance threshold.
    fn make_decoder(lookahead: usize) -> SpeculativeDecoder<'static> {
        let config = Qwen3Config::tiny_test();
        let params = SamplingParams::default();
        let engine = InferenceEngine::new(config, params, 42);
        let spec_config = SpeculativeConfig {
            lookahead,
            acceptance_threshold: 0.0,
        };
        SpeculativeDecoder::new(engine, spec_config)
    }

    /// Returns `n_positions` identical logit rows of length `vocab_size`
    /// that are `-5.0` everywhere except `10.0` at `peak_token` (when that
    /// index is in range).
    fn make_peaked_logits(
        vocab_size: usize,
        peak_token: usize,
        n_positions: usize,
    ) -> Vec<Vec<f32>> {
        (0..n_positions)
            .map(|_| {
                let mut logits = vec![-5.0f32; vocab_size];
                if peak_token < vocab_size {
                    logits[peak_token] = 10.0;
                }
                logits
            })
            .collect()
    }

    #[test]
    fn test_speculative_config_defaults() {
        let cfg = SpeculativeConfig::default();
        assert_eq!(cfg.lookahead, 5, "default lookahead should be 5");
        assert!(
            (cfg.acceptance_threshold - 0.0).abs() < f32::EPSILON,
            "default threshold should be 0.0"
        );
    }

    #[test]
    fn test_draft_generates_lookahead_tokens() {
        let mut decoder = make_decoder(3);
        let context = vec![1u32, 2, 3];
        let params = SamplingParams::default();
        let draft = decoder.draft(&context, &params);
        assert!(
            draft.len() <= 3,
            "draft should not exceed lookahead=3, got {}",
            draft.len()
        );
    }

    #[test]
    fn test_verify_accepts_high_probability_tokens() {
        let decoder = make_decoder(5);
        let params = SamplingParams::default();
        let vocab_size = 100;
        let draft_tokens = vec![42u32];
        let target_logits = make_peaked_logits(vocab_size, 42, 1);
        let accepted = decoder.verify(&draft_tokens, &target_logits, &params);
        assert_eq!(
            accepted.len(),
            1,
            "high-probability token should be accepted"
        );
        assert_eq!(accepted[0], 42);
    }

    #[test]
    fn test_verify_rejects_low_probability_tokens() {
        let mut decoder = make_decoder(5);
        let params = SamplingParams::default();
        let vocab_size = 1000;
        let draft_tokens = vec![500u32];
        let mut logits = vec![-10.0f32; vocab_size];
        logits[0] = 20.0;
        let target_logits = vec![logits];
        let mut rejections = 0;
        for seed_step in 0..20u64 {
            // `verify` seeds its generator from `total_steps`; without
            // varying it every iteration would draw the same sample.
            decoder.total_steps = seed_step;
            let accepted = decoder.verify(&draft_tokens, &target_logits, &params);
            if accepted.is_empty() {
                rejections += 1;
            }
        }
        assert!(
            rejections > 0,
            "low-probability token should be rejected at least sometimes"
        );
    }

    #[test]
    fn test_acceptance_rate_zero_at_start() {
        let decoder = make_decoder(5);
        assert!(
            (decoder.acceptance_rate() - 0.0).abs() < f32::EPSILON,
            "acceptance rate must be 0.0 before any steps"
        );
        assert_eq!(decoder.total_steps, 0);
        assert_eq!(decoder.total_draft_tokens, 0);
        assert_eq!(decoder.total_accepted_tokens, 0);
    }

    #[test]
    fn test_acceptance_rate_updates_after_step() {
        let mut decoder = make_decoder(4);
        let params = SamplingParams::default();
        let context = vec![1u32, 2, 3];
        let vocab_size = 32usize;
        let target_logits = make_peaked_logits(vocab_size, 5, 4);
        let step = decoder.step(&context, &target_logits, &params);
        assert_eq!(decoder.total_steps, 1, "one step should have been recorded");
        assert_eq!(
            decoder.total_draft_tokens,
            step.draft_tokens.len() as u64,
            "draft token count should match"
        );
        assert!(
            decoder.total_accepted_tokens <= decoder.total_draft_tokens,
            "accepted cannot exceed drafted"
        );
    }

    #[test]
    fn test_generate_speculative_returns_tokens() {
        let mut decoder = make_decoder(3);
        let params = SamplingParams::default();
        let prompt = vec![1u32, 2, 3];
        let output = decoder.generate_speculative(&prompt, 5, &params);
        assert!(
            output.len() <= 5,
            "output should not exceed max_tokens=5, got {}",
            output.len()
        );
    }

    #[test]
    fn test_should_accept_target_above_draft() {
        assert!(
            SpeculativeDecoder::should_accept(0.1, 0.9, 0.0, 0.99),
            "target > draft: must accept even with rng_sample near 1.0"
        );
        assert!(
            SpeculativeDecoder::should_accept(0.05, 0.5, 0.0, 0.0),
            "target > draft: must accept with rng_sample=0.0"
        );
    }

    #[test]
    fn test_should_accept_target_below_draft_probabilistic() {
        assert!(
            SpeculativeDecoder::should_accept(1.0, 0.1, 0.0, 0.05),
            "rng_sample=0.05 < accept_prob=0.1, should accept"
        );
        assert!(
            !SpeculativeDecoder::should_accept(1.0, 0.1, 0.0, 0.5),
            "rng_sample=0.5 >= accept_prob=0.1, should reject"
        );
    }

    #[test]
    fn test_speedup_estimate_below_lookahead() {
        let mut decoder = make_decoder(5);
        assert!(
            (decoder.speedup_estimate() - 1.0).abs() < f32::EPSILON,
            "initial speedup should be 1.0"
        );
        decoder.total_steps = 10;
        decoder.total_draft_tokens = 30;
        decoder.total_accepted_tokens = 15;
        let speedup = decoder.speedup_estimate();
        assert!(
            (speedup - 1.5).abs() < 1e-4,
            "speedup should be 1.5 (avg accepted per step), got {speedup}"
        );
        assert!(
            speedup <= decoder.config.lookahead as f32 + 1.0,
            "speedup cannot exceed lookahead+1"
        );
    }
}