use super::types::{SpeculativeConfig, SpeculativeResult};
/// Small, fast, deterministic pseudo-random generator (Marsaglia xorshift, shift triple 13/7/17).
///
/// Supplies the uniforms used for sampling. It is not cryptographically secure; its purpose is
/// reproducibility: the same seed always yields the same stream.
#[derive(Debug, Clone)]
pub struct Xorshift64 {
    /// Current state. It is never zero: `new` rejects a zero seed, and the xorshift step maps any
    /// non-zero state to a non-zero state.
    state: u64,
}
impl Xorshift64 {
    /// Creates a generator from `seed`.
    ///
    /// A seed of `0` is replaced by a fixed non-zero constant, because an all-zero state is a
    /// fixed point of the xorshift step and would make every output zero.
    ///
    /// Public so that callers outside this module can drive the public sampling methods that
    /// take `&mut Xorshift64` (e.g. `TokenDist::sample`).
    pub fn new(seed: u64) -> Self {
        Self {
            state: if seed == 0 { 0xDEAD_BEEF_CAFE_1234 } else { seed },
        }
    }

    /// Advances the state and returns the next 64-bit output.
    pub fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Returns a uniform `f64` in `[0, 1)` built from the top 53 bits of the next output.
    pub fn next_f64(&mut self) -> f64 {
        // 53 bits exactly fill an f64 mantissa, so the quotient is exact and strictly below 1.
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}
/// A distribution over token ids held as raw (unnormalised) logits; entry `i` scores token `i`.
pub struct TokenDist {
    /// One logit per token id.
    pub logits: Vec<f64>,
}
impl TokenDist {
    /// Wraps `logits` as-is, without copying or validating them.
    pub fn new(logits: Vec<f64>) -> Self {
        TokenDist { logits }
    }

    /// Probabilities of every token at temperature `1.0`.
    pub fn softmax(&self) -> Vec<f64> {
        self.tempered(1.0)
    }

    /// Samples a token id from the distribution at `temperature`, consuming one uniform from `rng`.
    pub fn sample(&self, temperature: f64, rng: &mut Xorshift64) -> u32 {
        let probs = self.tempered(temperature);
        let u = rng.next_f64();
        categorical_sample(&probs, u)
    }

    /// Like [`Self::sample`], but first restricts the distribution to its top-`top_p` nucleus.
    pub fn sample_top_p(&self, temperature: f64, top_p: f64, rng: &mut Xorshift64) -> u32 {
        let nucleus = apply_top_p(&self.tempered(temperature), top_p);
        let u = rng.next_f64();
        categorical_sample(&nucleus, u)
    }

    /// Probability of `token` at `temperature`; `0.0` for ids beyond the end of `logits`.
    pub fn prob(&self, token: u32, temperature: f64) -> f64 {
        match self.tempered(temperature).get(token as usize) {
            Some(&p) => p,
            None => 0.0,
        }
    }

    /// Softmax of the stored logits at `temperature`, delegating to `softmax_with_temperature`.
    fn tempered(&self, temperature: f64) -> Vec<f64> {
        softmax_with_temperature(&self.logits, temperature)
    }
}
/// Converts `logits` into a probability vector after dividing every logit by `temperature`.
///
/// * A `temperature` that is not strictly positive (including NaN) falls back to `1.0`.
/// * The maximum scaled logit is subtracted before exponentiation, so large logits cannot
///   overflow `exp`.
/// * If any scaled logit is `+inf` (including finite logits that overflow when divided by a tiny
///   temperature), the mass is split evenly across exactly those entries. This is the limit of
///   softmax, and it avoids the `inf - inf = NaN` the plain formula would produce.
/// * If normalisation is otherwise impossible (e.g. every logit is `-inf` or NaN), a uniform
///   distribution is returned.
///
/// Returns an empty vector for empty input.
fn softmax_with_temperature(logits: &[f64], temperature: f64) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    // Written as `> 0.0` so that NaN also takes the fallback (`NaN <= 0.0` is false).
    let temp = if temperature > 0.0 { temperature } else { 1.0 };
    let scaled: Vec<f64> = logits.iter().map(|&l| l / temp).collect();
    // `f64::max` ignores NaN operands, so `max_val` is the largest non-NaN entry (or -inf).
    let max_val = scaled.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if max_val == f64::INFINITY {
        let n_inf = scaled.iter().filter(|&&s| s == f64::INFINITY).count() as f64;
        return scaled
            .iter()
            .map(|&s| if s == f64::INFINITY { 1.0 / n_inf } else { 0.0 })
            .collect();
    }
    let exps: Vec<f64> = scaled.iter().map(|&s| (s - max_val).exp()).collect();
    let sum: f64 = exps.iter().sum();
    if sum <= 0.0 || sum.is_nan() {
        let u = 1.0 / logits.len() as f64;
        return vec![u; logits.len()];
    }
    exps.iter().map(|&e| e / sum).collect()
}
/// Nucleus (top-p) filtering.
///
/// Keeps the smallest set of highest-probability entries whose cumulative mass reaches `top_p`,
/// zeroes the rest, and renormalises the survivors.
///
/// * `top_p >= 1.0` or empty input returns a copy of `probs` unchanged.
/// * With non-negative probabilities, `top_p <= 0.0` keeps only the single most probable entry.
/// * Ties are broken in favour of the lower index (the sort is stable).
/// * If the surviving mass is not positive, the filtered values are returned without
///   renormalisation.
fn apply_top_p(probs: &[f64], top_p: f64) -> Vec<f64> {
    if top_p >= 1.0 || probs.is_empty() {
        return probs.to_vec();
    }
    let mut order: Vec<usize> = (0..probs.len()).collect();
    // Descending by probability. `total_cmp` is a genuine total order. The previous
    // `partial_cmp(..).unwrap_or(Equal)` is inconsistent when NaN is present, and `sort_by` may
    // panic on such comparators. For ordinary non-negative probabilities the ordering is the same.
    order.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));

    let mut keep = vec![false; probs.len()];
    let mut cumulative = 0.0;
    for &idx in &order {
        keep[idx] = true;
        cumulative += probs[idx];
        if cumulative >= top_p {
            break;
        }
    }

    let mut filtered: Vec<f64> = probs
        .iter()
        .zip(&keep)
        .map(|(&p, &k)| if k { p } else { 0.0 })
        .collect();
    let sum: f64 = filtered.iter().sum();
    if sum > 0.0 {
        filtered.iter_mut().for_each(|p| *p /= sum);
    }
    filtered
}
/// Inverse-CDF sampling: returns the first index whose running probability total exceeds `u`.
///
/// `u` is clamped into `[0, 1 - EPSILON]`. If rounding (or an unnormalised `probs`) leaves the
/// running total at or below `u`, the fallback is the last index with positive probability, not
/// simply the last index. This ensures a token with zero probability (for example one removed
/// by top-p filtering) is never returned. If no entry is positive, the last index is returned,
/// and `0` is returned for an empty slice.
fn categorical_sample(probs: &[f64], u: f64) -> u32 {
    let u = u.clamp(0.0, 1.0 - f64::EPSILON);
    let mut cumulative = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i as u32;
        }
    }
    probs
        .iter()
        .rposition(|&p| p > 0.0)
        .unwrap_or_else(|| probs.len().saturating_sub(1)) as u32
}
/// Draws the correction token after a draft token has been rejected.
///
/// Over the common prefix of the two slices, index `i` is sampled with probability proportional
/// to `max(target_probs[i] - draft_probs[i], 0)`. If that residual carries essentially no mass
/// (`<= 1e-15`), the sample is drawn from `target_probs` instead. Returns `0` when the slices
/// share no entries or when the fallback target mass is not positive. If the running total never
/// exceeds the drawn threshold, the last index of the common prefix is returned. Exactly one
/// uniform is taken from `rng` whenever a sample is drawn.
fn residual_sample(target_probs: &[f64], draft_probs: &[f64], rng: &mut Xorshift64) -> u32 {
    // Inverse-CDF walk: first index whose running total exceeds `threshold`, else `fallback`.
    fn walk(weights: &[f64], threshold: f64, fallback: usize) -> u32 {
        let mut running = 0.0;
        weights
            .iter()
            .position(|&w| {
                running += w;
                threshold < running
            })
            .unwrap_or(fallback) as u32
    }

    let len = target_probs.len().min(draft_probs.len());
    if len == 0 {
        return 0;
    }
    let last = len - 1;
    let target = &target_probs[..len];

    let residual: Vec<f64> = target
        .iter()
        .zip(draft_probs)
        .map(|(&t, &d)| {
            let diff = t - d;
            if diff > 0.0 {
                diff
            } else {
                0.0
            }
        })
        .collect();
    let residual_mass: f64 = residual.iter().sum();

    if residual_mass <= 1e-15 {
        let target_mass: f64 = target.iter().sum();
        if target_mass <= 0.0 {
            return 0;
        }
        let threshold = rng.next_f64() * target_mass;
        return walk(target, threshold, last);
    }

    let threshold = rng.next_f64() * residual_mass;
    walk(&residual, threshold, last)
}
/// Speculative decoding driver: a cheap draft model proposes tokens and a target model verifies
/// them with a probabilistic accept/reject rule.
pub struct SpeculativeDecoder;
impl SpeculativeDecoder {
    /// Verifies one block of draft tokens against target logits.
    ///
    /// Position `i` judges `draft_tokens[i]` using `draft_logits[i]` and `target_logits[i]`.
    /// Only the first `n = min(len(draft_tokens), len(draft_logits), len(target_logits))`
    /// positions are considered. Both distributions are tempered with `config.temperature` and
    /// filtered with `config.top_p` before comparison.
    ///
    /// Returns `(accepted, next)`:
    /// * on the first rejection, `accepted` is the draft prefix before it, and `next` is a
    ///   correction token drawn from the positive part of `target - draft` at that position;
    /// * if all `n` tokens are accepted, `next` is a bonus token. It is sampled from
    ///   `target_logits[n]` when the target supplied that extra row; otherwise
    ///   `target_logits[n - 1]` is reused, because no later distribution is available;
    /// * if `n == 0`, the result is `(vec![], 0)`, where the `0` is a placeholder, not a sample.
    ///
    /// The random stream is seeded deterministically from the inputs, so identical inputs give
    /// identical results.
    pub fn rejection_sampling_step(
        draft_logits: &[Vec<f64>],
        target_logits: &[Vec<f64>],
        draft_tokens: &[u32],
        config: &SpeculativeConfig,
    ) -> (Vec<u32>, u32) {
        let raw_seed = u64::from(draft_tokens.first().copied().unwrap_or(0))
            ^ (draft_logits.first().map_or(0, |v| v.len()) as u64)
                .wrapping_mul(0x9E37_79B9_7F4A_7C15);
        // The raw seed varies between calls mostly in its low bits, and xorshift diffuses low
        // bits slowly. Without mixing, the first uniform was nearly identical across calls.
        let mut rng = Xorshift64::new(Self::splitmix64(raw_seed));
        Self::verify_block(draft_logits, target_logits, draft_tokens, config, &mut rng)
    }

    /// Runs speculative decoding until `config.max_tokens` tokens are produced or a model
    /// returns nothing.
    ///
    /// Each round calls `draft_fn(ctx)` for proposals (token plus draft logits). It then calls
    /// `target_fn(ctx, drafted_tokens)` for per-position target logits and appends the accepted
    /// prefix plus the correction/bonus token. Decoding stops early if `draft_fn` returns no
    /// proposals or `target_fn` returns no logits. A single random stream, seeded
    /// deterministically from `context`, is used for the whole run, so rounds draw fresh
    /// uniforms instead of reusing correlated ones.
    pub fn decode<D, T>(
        context: &[u32],
        draft_fn: D,
        target_fn: T,
        config: &SpeculativeConfig,
    ) -> SpeculativeResult
    where
        D: Fn(&[u32]) -> Vec<(u32, Vec<f64>)>,
        T: Fn(&[u32], &[u32]) -> Vec<Vec<f64>>,
    {
        let mut current_ctx: Vec<u32> = context.to_vec();
        let mut output_tokens: Vec<u32> = Vec::new();
        let mut n_draft_total = 0_usize;
        let mut n_accepted_total = 0_usize;
        let mut n_verification_calls = 0_usize;
        let seed = context.iter().fold(
            Self::splitmix64(context.len() as u64),
            |h, &t| Self::splitmix64(h ^ u64::from(t)),
        );
        let mut rng = Xorshift64::new(seed);

        while output_tokens.len() < config.max_tokens {
            let draft_result = draft_fn(&current_ctx);
            if draft_result.is_empty() {
                break;
            }
            let remaining = config.max_tokens - output_tokens.len();
            let draft_len = draft_result.len().min(remaining);
            // Move the proposals apart instead of cloning every logits vector.
            let (draft_tokens, draft_logits): (Vec<u32>, Vec<Vec<f64>>) =
                draft_result.into_iter().take(draft_len).unzip();
            n_draft_total += draft_len;

            let target_logits = target_fn(&current_ctx, &draft_tokens);
            n_verification_calls += 1;
            if target_logits.is_empty() {
                // Nothing to verify against. Stop rather than emit the placeholder token `0`.
                break;
            }

            let (accepted, next_token) = Self::verify_block(
                &draft_logits,
                &target_logits,
                &draft_tokens,
                config,
                &mut rng,
            );
            n_accepted_total += accepted.len();
            // accepted.len() <= draft_len <= remaining, so the accepted prefix always fits.
            current_ctx.extend_from_slice(&accepted);
            output_tokens.extend_from_slice(&accepted);
            if output_tokens.len() < config.max_tokens {
                current_ctx.push(next_token);
                output_tokens.push(next_token);
            }
        }

        SpeculativeResult::new(
            output_tokens,
            n_draft_total,
            n_accepted_total,
            n_verification_calls,
        )
    }

    /// Core of `rejection_sampling_step`, drawing every uniform from the caller's `rng`.
    fn verify_block(
        draft_logits: &[Vec<f64>],
        target_logits: &[Vec<f64>],
        draft_tokens: &[u32],
        config: &SpeculativeConfig,
        rng: &mut Xorshift64,
    ) -> (Vec<u32>, u32) {
        let n = draft_tokens
            .len()
            .min(draft_logits.len())
            .min(target_logits.len());
        if n == 0 {
            return (Vec::new(), 0);
        }
        let temp = config.temperature;
        let top_p = config.top_p;
        let mut accepted: Vec<u32> = Vec::with_capacity(n);

        for ((&draft_token, d_logits), t_logits) in
            draft_tokens.iter().zip(draft_logits).zip(target_logits)
        {
            let draft_probs = apply_top_p(&softmax_with_temperature(d_logits, temp), top_p);
            let target_probs = apply_top_p(&softmax_with_temperature(t_logits, temp), top_p);
            let q = draft_probs.get(draft_token as usize).copied().unwrap_or(0.0);
            let p = target_probs.get(draft_token as usize).copied().unwrap_or(0.0);
            if Self::accept_or_reject(q, p, config.acceptance_threshold, rng) {
                accepted.push(draft_token);
            } else {
                let correction = residual_sample(&target_probs, &draft_probs, rng);
                return (accepted, correction);
            }
        }

        // target_logits[i] is the distribution for the token after the first i drafts, so once
        // all n drafts are accepted the next token comes from row n, when the target supplied it.
        let bonus_logits = target_logits.get(n).unwrap_or(&target_logits[n - 1]);
        let bonus_probs = apply_top_p(&softmax_with_temperature(bonus_logits, temp), top_p);
        let bonus = categorical_sample(&bonus_probs, rng.next_f64());
        (accepted, bonus)
    }

    /// SplitMix64 finaliser: scrambles `z` so that nearby inputs give unrelated seeds.
    fn splitmix64(z: u64) -> u64 {
        let mut z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Accept/reject decision for one draft token with draft probability `draft_prob` (q) and
    /// target probability `target_prob` (p).
    ///
    /// Accepts without drawing when `q <= threshold` or `p >= q`. Otherwise it draws one uniform
    /// from `rng` and accepts with probability `p / q`.
    ///
    /// NOTE(review): with `threshold >= 0.0` a token the draft gave zero mass is accepted even if
    /// the target also gives it zero mass; confirm this is the intended meaning of
    /// `acceptance_threshold`.
    fn accept_or_reject(
        draft_prob: f64,
        target_prob: f64,
        threshold: f64,
        rng: &mut Xorshift64,
    ) -> bool {
        if draft_prob <= threshold {
            return true;
        }
        if target_prob >= draft_prob {
            return true;
        }
        rng.next_f64() < target_prob / draft_prob
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::inference::speculative::SpeculativeConfig;

    #[test]
    fn tokendist_softmax_sums_to_one() {
        let d = TokenDist::new(vec![1.0, 2.0, 0.5, -1.0]);
        let probs = d.softmax();
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "sum={sum}");
    }

    #[test]
    fn tokendist_softmax_nonnegative() {
        let d = TokenDist::new(vec![-100.0, 0.0, 100.0]);
        for &p in &d.softmax() {
            assert!(p >= 0.0);
        }
    }

    #[test]
    fn tokendist_sample_returns_valid_index() {
        let vocab = 8;
        let d = TokenDist::new(vec![1.0; vocab]);
        let mut rng = Xorshift64::new(12345);
        for _ in 0..50 {
            let tok = d.sample(1.0, &mut rng);
            assert!((tok as usize) < vocab);
        }
    }

    #[test]
    fn tokendist_prob_out_of_range_is_zero() {
        let d = TokenDist::new(vec![1.0, 2.0]);
        assert!((d.prob(99, 1.0) - 0.0).abs() < 1e-12);
    }

    #[test]
    fn step_all_accepted_when_draft_equals_target() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let draft_logits = vec![logits.clone(), logits.clone(), logits.clone()];
        let target_logits = draft_logits.clone();
        let draft_tokens = vec![3u32, 3u32, 3u32];
        let cfg = SpeculativeConfig::default();
        let (accepted, _correction) = SpeculativeDecoder::rejection_sampling_step(
            &draft_logits,
            &target_logits,
            &draft_tokens,
            &cfg,
        );
        assert_eq!(
            accepted.len(),
            draft_tokens.len(),
            "all tokens should be accepted when draft==target"
        );
    }

    #[test]
    fn step_some_rejected_when_distributions_differ() {
        // The draft puts essentially all mass on token 0 and the target on token 3. Token 0 must
        // therefore be rejected (p / q is ~0), and the correction must be token 3, the only entry
        // with residual mass. The step seeds its RNG from its inputs, so repeating the identical
        // call would add nothing.
        let draft_logits = vec![vec![100.0, -100.0, -100.0, -100.0]; 2];
        let target_logits = vec![vec![-100.0, -100.0, -100.0, 100.0]; 2];
        let draft_tokens = vec![0u32, 0u32];
        let cfg = SpeculativeConfig {
            draft_steps: 2,
            acceptance_threshold: 0.0,
            ..Default::default()
        };
        let (accepted, correction) = SpeculativeDecoder::rejection_sampling_step(
            &draft_logits,
            &target_logits,
            &draft_tokens,
            &cfg,
        );
        assert!(
            accepted.is_empty(),
            "first draft token should be rejected, accepted={accepted:?}"
        );
        assert_eq!(correction, 3, "correction should come from the target's mode");
    }

    #[test]
    fn step_empty_input_returns_empty() {
        let cfg = SpeculativeConfig::default();
        let (acc, _corr) = SpeculativeDecoder::rejection_sampling_step(&[], &[], &[], &cfg);
        assert!(acc.is_empty());
    }

    #[test]
    fn decode_acceptance_rate_leq_one() {
        let logits = vec![1.0f64; 4];
        let cfg = SpeculativeConfig {
            draft_steps: 3,
            max_tokens: 12,
            ..Default::default()
        };
        let result = SpeculativeDecoder::decode(
            &[0u32, 1u32],
            |_ctx| {
                vec![
                    (0, logits.clone()),
                    (1, logits.clone()),
                    (2, logits.clone()),
                ]
            },
            |_ctx, draft| draft.iter().map(|_| logits.clone()).collect(),
            &cfg,
        );
        assert!(
            result.acceptance_rate <= 1.0 + 1e-10,
            "acceptance_rate={:.4}",
            result.acceptance_rate
        );
    }

    #[test]
    fn decode_perfect_draft_high_acceptance() {
        let logits = vec![0.0f64; 4];
        let cfg = SpeculativeConfig {
            draft_steps: 2,
            max_tokens: 10,
            ..Default::default()
        };
        let result = SpeculativeDecoder::decode(
            &[0u32],
            |_ctx| vec![(0, logits.clone()), (1, logits.clone())],
            |_ctx, draft| draft.iter().map(|_| logits.clone()).collect(),
            &cfg,
        );
        assert!(
            result.acceptance_rate >= 0.99,
            "expected high acceptance, got {:.4}",
            result.acceptance_rate
        );
        // Replaces a tautological `len() == len()` check: decoding must produce something and
        // must never exceed the token budget.
        assert!(
            (1..=cfg.max_tokens).contains(&result.accepted_tokens.len()),
            "unexpected output length {}",
            result.accepted_tokens.len()
        );
    }

    #[test]
    fn speculative_config_default_works() {
        let cfg = SpeculativeConfig::default();
        assert_eq!(cfg.draft_steps, 4);
    }
}