use crate::Result;
/// Verdict of the accept/reject test applied to a single drafted token.
///
/// Produced by `modified_rejection_step` and consumed by the decoding loops
/// in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AcceptOutcome {
/// The drafted token passed the test and is committed unchanged.
Accepted,
/// The drafted token failed the test; the decoding loops in this file then
/// draw a replacement from `adjusted_distribution` and end the round.
Rejected,
}
/// Tuning knobs for the vanilla speculative-decoding loops in this file.
#[derive(Debug, Clone)]
pub struct VanillaConfig {
/// Upper bound on the number of tokens drafted per round; each round uses
/// `min(draft_lookahead, tokens still to generate)`. With a value of `0` the
/// decoding loops stop immediately and return no tokens.
pub draft_lookahead: usize,
/// Temperature handed to `softmax_with_temperature` for both the draft and
/// the target logits.
pub temperature: f32,
/// Presumably a nucleus (top-p) threshold.
/// NOTE(review): no function visible in this file reads this field; confirm
/// whether it is applied elsewhere.
pub top_p: f32,
}
impl Default for VanillaConfig {
fn default() -> Self {
Self {
draft_lookahead: 4,
temperature: 0.7,
top_p: 0.95,
}
}
}
/// Runs the accept/reject test of speculative sampling for one drafted token.
///
/// * `q_draft` – probability the draft distribution assigned to the token.
/// * `p_target` – probability the target distribution assigns to the token.
/// * `u` – a uniform random draw used to make the decision.
///
/// The token is accepted when `u <= min(1, p_target / q_draft)`, with two
/// special cases handled up front:
///
/// * `q_draft <= 0` is accepted unconditionally, which avoids dividing by
///   zero.
/// * `p_target <= 0` (with a positive `q_draft`) is always rejected. Without
///   this guard a draw of exactly `u == 0.0` would satisfy `u <= 0.0` and
///   accept a token the target considers impossible.
///
/// # Panics
///
/// In builds with debug assertions enabled, panics if any argument lies
/// outside `[0, 1]` (which includes NaN).
pub fn modified_rejection_step(q_draft: f32, p_target: f32, u: f32) -> AcceptOutcome {
    debug_assert!((0.0..=1.0).contains(&q_draft), "q_draft out of [0,1]");
    debug_assert!((0.0..=1.0).contains(&p_target), "p_target out of [0,1]");
    debug_assert!((0.0..=1.0).contains(&u), "u out of [0,1]");

    if q_draft <= 0.0 {
        return AcceptOutcome::Accepted;
    }
    // A token the target gives no mass to must never be emitted, even for the
    // boundary draw `u == 0.0`.
    if p_target <= 0.0 {
        return AcceptOutcome::Rejected;
    }

    let ratio = (p_target / q_draft).min(1.0);
    if u <= ratio {
        AcceptOutcome::Accepted
    } else {
        AcceptOutcome::Rejected
    }
}
/// Builds the residual distribution `normalize(max(p_target - q_draft, 0))`
/// that a replacement token is sampled from after a rejection.
///
/// # Errors
///
/// Returns `Error::Sampling` when the two slices differ in length, or when
/// the clamped differences sum to a value that is not strictly positive and
/// finite (for example when the two distributions are identical).
pub fn adjusted_distribution(p_target: &[f32], q_draft: &[f32]) -> Result<Vec<f32>> {
    let (target_len, draft_len) = (p_target.len(), q_draft.len());
    if target_len != draft_len {
        return Err(crate::Error::Sampling(format!(
            "vocab size mismatch: target={}, draft={}",
            target_len, draft_len
        )));
    }

    // Keep only the mass by which the target exceeds the draft.
    let mut residual: Vec<f32> = Vec::with_capacity(target_len);
    residual.extend(
        p_target
            .iter()
            .zip(q_draft)
            .map(|(&p, &q)| f32::max(p - q, 0.0)),
    );

    let sum = residual.iter().sum::<f32>();
    let normalizable = sum > 0.0 && sum.is_finite();
    if !normalizable {
        return Err(crate::Error::Sampling(format!(
            "adjusted distribution has non-positive mass: sum={sum}"
        )));
    }

    Ok(residual.into_iter().map(|mass| mass / sum).collect())
}
/// Generates up to `max_new_tokens` tokens after `prompt` using vanilla
/// speculative decoding: `draft` proposes tokens and `target` verifies them.
///
/// Both decoders are reset and fed `prompt` first. Each round then
///
/// 1. drafts `k = min(config.draft_lookahead, tokens still needed)` tokens
///    from `draft`, sampling at `config.temperature`;
/// 2. asks `target` for logits over those drafts with one `batched_logits`
///    call (`k + 1` rows are expected);
/// 3. walks the drafts in order, keeping each one that passes
///    [`modified_rejection_step`]; the first failure is replaced by a sample
///    from [`adjusted_distribution`] and ends the round;
/// 4. if every draft was kept and the budget allows, samples one extra
///    "bonus" token from the last row of target probabilities;
/// 5. rewinds both decoders to the kept draft prefix and feeds them the
///    remaining committed tokens.
///
/// `config.top_p` is not consulted by this function. A `draft_lookahead` of
/// `0` makes the function return an empty vector.
///
/// # Errors
///
/// Returns `Error::Sampling` when the two decoders report different
/// vocabulary sizes, and propagates any error from the decoders, the sampling
/// helpers, or [`adjusted_distribution`].
///
/// # Panics
///
/// Panics on out-of-bounds indexing if `batched_logits` returns fewer rows
/// than there are drafted tokens, or if a probability row is shorter than a
/// sampled token id.
pub fn run_vanilla_sd<T, D, R>(
    target: &mut T,
    draft: &mut D,
    prompt: &[u32],
    max_new_tokens: usize,
    config: &VanillaConfig,
    rng: &mut R,
) -> Result<Vec<u32>>
where
    T: crate::model::Decoder + ?Sized,
    D: crate::model::Decoder + ?Sized,
    R: rand::Rng + ?Sized,
{
    use crate::sampling::{sample_from_distribution, softmax_with_temperature};

    if target.vocab_size() != draft.vocab_size() {
        return Err(crate::Error::Sampling(format!(
            "target vocab {} != draft vocab {}",
            target.vocab_size(),
            draft.vocab_size()
        )));
    }

    target.reset();
    draft.reset();
    target.observe(prompt)?;
    draft.observe(prompt)?;

    let mut generated: Vec<u32> = Vec::with_capacity(max_new_tokens);
    while generated.len() < max_new_tokens {
        let remaining = max_new_tokens - generated.len();
        let k = config.draft_lookahead.min(remaining);
        if k == 0 {
            // A zero lookahead can never make progress; stop instead of spinning.
            break;
        }

        let pre_target_len = target.history_len();
        let pre_draft_len = draft.history_len();

        // Draft phase: sample k tokens, remembering the distribution each one
        // was drawn from so it can be compared against the target later.
        let mut draft_tokens: Vec<u32> = Vec::with_capacity(k);
        let mut draft_dists: Vec<Vec<f32>> = Vec::with_capacity(k);
        for _ in 0..k {
            let logits = draft.next_logits()?;
            let probs = softmax_with_temperature(&logits, config.temperature)?;
            let tok = sample_from_distribution(rng, &probs)? as u32;
            draft_tokens.push(tok);
            draft_dists.push(probs);
            draft.observe(&[tok])?;
        }

        // Verify phase.
        let target_batched = target.batched_logits(&draft_tokens)?;
        debug_assert_eq!(target_batched.len(), k + 1);

        let mut commits: Vec<u32> = Vec::with_capacity(k + 1);
        // Number of leading entries of `commits` that are verbatim draft
        // tokens. A replacement or bonus token is NOT counted here.
        let mut accepted_drafts = 0usize;
        let mut rejected = false;
        for i in 0..k {
            let p_probs = softmax_with_temperature(&target_batched[i], config.temperature)?;
            let q_probs = &draft_dists[i];
            let token = draft_tokens[i] as usize;
            let p = p_probs[token];
            let q = q_probs[token];
            let u: f32 = rng.gen();
            match modified_rejection_step(q, p, u) {
                AcceptOutcome::Accepted => {
                    commits.push(draft_tokens[i]);
                    accepted_drafts += 1;
                    if generated.len() + commits.len() >= max_new_tokens {
                        break;
                    }
                }
                AcceptOutcome::Rejected => {
                    let adj = adjusted_distribution(&p_probs, q_probs)?;
                    let new_tok = sample_from_distribution(rng, &adj)? as u32;
                    commits.push(new_tok);
                    rejected = true;
                    break;
                }
            }
        }

        // Every draft survived: the last target row yields one extra token.
        if !rejected && commits.len() == k && generated.len() + commits.len() < max_new_tokens {
            let p_probs = softmax_with_temperature(&target_batched[k], config.temperature)?;
            let bonus = sample_from_distribution(rng, &p_probs)? as u32;
            commits.push(bonus);
        }

        // Re-sync both decoders with what was actually committed. Only the
        // accepted draft prefix may stay in the histories; the rejected draft
        // token must be dropped and the replacement (or bonus) token observed
        // in its place. Counting the replacement as "accepted" would leave the
        // rejected token in the context and never feed the replacement.
        // NOTE(review): assumes `batched_logits` leaves the drafted tokens in
        // the target history and `rollback_to` truncates to the given length —
        // confirm against the `Decoder` trait.
        target.rollback_to(pre_target_len + accepted_drafts)?;
        target.observe(&commits[accepted_drafts..])?;
        draft.rollback_to(pre_draft_len + accepted_drafts)?;
        draft.observe(&commits[accepted_drafts..])?;

        generated.extend_from_slice(&commits);
    }
    Ok(generated)
}
/// Convenience entry point that runs `run_vanilla_sd` with two
/// `MockDecoder`s.
///
/// Only compiled for test builds or when the `test-util` feature is enabled.
/// All arguments are forwarded unchanged and the result of `run_vanilla_sd`
/// is returned as-is.
#[cfg(any(test, feature = "test-util"))]
pub fn run_vanilla_sd_with_mock<R: rand::Rng>(
target: &mut crate::model::mock::MockDecoder,
draft: &mut crate::model::mock::MockDecoder,
prompt: &[u32],
max_new_tokens: usize,
config: &VanillaConfig,
rng: &mut R,
) -> Result<Vec<u32>> {
run_vanilla_sd(target, draft, prompt, max_new_tokens, config, rng)
}
/// Vanilla speculative decoding with a per-token callback and stop tokens.
///
/// Generates up to `opts.max_new_tokens` tokens after `prompt`: `draft`
/// proposes tokens and `target` verifies them. Both decoders are reset and
/// fed `prompt` first. Each round then
///
/// 1. drafts `k = min(config.draft_lookahead, tokens still needed)` tokens
///    from `draft`, sampling at `config.temperature`;
/// 2. asks `target` for logits over those drafts with one `batched_logits`
///    call (`k + 1` rows are expected);
/// 3. keeps each draft that passes [`modified_rejection_step`]; the first
///    failure is replaced by a sample from [`adjusted_distribution`] and ends
///    the round;
/// 4. if every draft was kept and the budget allows, samples one extra
///    "bonus" token from the last row of target probabilities;
/// 5. passes the committed tokens to `on_token` in order. Generation stops
///    after the first token for which `on_token` returns `false` or which is
///    listed in `opts.stop_tokens`; that token is still part of the output
///    and later tokens of the round are discarded;
/// 6. rewinds both decoders to the kept draft prefix and feeds them the
///    remaining committed tokens.
///
/// `on_token` is invoked for every token up to and including the one that
/// ends generation. `config.top_p` is not consulted by this function.
///
/// # Errors
///
/// Returns `Error::Sampling` when the two decoders report different
/// vocabulary sizes, and propagates any error from the decoders, the sampling
/// helpers, or [`adjusted_distribution`].
///
/// # Panics
///
/// Panics on out-of-bounds indexing if `batched_logits` returns fewer rows
/// than there are drafted tokens, or if a probability row is shorter than a
/// sampled token id.
pub fn run_vanilla_sd_with<T, D, R, F>(
    target: &mut T,
    draft: &mut D,
    prompt: &[u32],
    opts: &crate::engine::GenerationOptions,
    config: &VanillaConfig,
    rng: &mut R,
    mut on_token: F,
) -> Result<Vec<u32>>
where
    T: crate::model::Decoder + ?Sized,
    D: crate::model::Decoder + ?Sized,
    R: rand::Rng + ?Sized,
    F: FnMut(u32) -> bool,
{
    use crate::sampling::{sample_from_distribution, softmax_with_temperature};

    if target.vocab_size() != draft.vocab_size() {
        return Err(crate::Error::Sampling(format!(
            "target vocab {} != draft vocab {}",
            target.vocab_size(),
            draft.vocab_size()
        )));
    }

    target.reset();
    draft.reset();
    target.observe(prompt)?;
    draft.observe(prompt)?;

    let max_new_tokens = opts.max_new_tokens;
    let mut generated: Vec<u32> = Vec::with_capacity(max_new_tokens);
    'rounds: while generated.len() < max_new_tokens {
        let remaining = max_new_tokens - generated.len();
        let k = config.draft_lookahead.min(remaining);
        if k == 0 {
            // A zero lookahead can never make progress; stop instead of spinning.
            break;
        }

        let pre_target_len = target.history_len();
        let pre_draft_len = draft.history_len();

        // Draft phase: sample k tokens, remembering the distribution each one
        // was drawn from so it can be compared against the target later.
        let mut draft_tokens: Vec<u32> = Vec::with_capacity(k);
        let mut draft_dists: Vec<Vec<f32>> = Vec::with_capacity(k);
        for _ in 0..k {
            let logits = draft.next_logits()?;
            let probs = softmax_with_temperature(&logits, config.temperature)?;
            let tok = sample_from_distribution(rng, &probs)? as u32;
            draft_tokens.push(tok);
            draft_dists.push(probs);
            draft.observe(&[tok])?;
        }

        // Verify phase.
        let target_batched = target.batched_logits(&draft_tokens)?;
        debug_assert_eq!(target_batched.len(), k + 1);

        let mut commits: Vec<u32> = Vec::with_capacity(k + 1);
        // Number of leading entries of `commits` that are verbatim draft
        // tokens. A replacement or bonus token is NOT counted here.
        let mut accepted_drafts = 0usize;
        let mut rejected = false;
        for i in 0..k {
            let p_probs = softmax_with_temperature(&target_batched[i], config.temperature)?;
            let q_probs = &draft_dists[i];
            let token = draft_tokens[i] as usize;
            let p = p_probs[token];
            let q = q_probs[token];
            let u: f32 = rng.gen();
            match modified_rejection_step(q, p, u) {
                AcceptOutcome::Accepted => {
                    commits.push(draft_tokens[i]);
                    accepted_drafts += 1;
                    if generated.len() + commits.len() >= max_new_tokens {
                        break;
                    }
                }
                AcceptOutcome::Rejected => {
                    let adj = adjusted_distribution(&p_probs, q_probs)?;
                    let new_tok = sample_from_distribution(rng, &adj)? as u32;
                    commits.push(new_tok);
                    rejected = true;
                    break;
                }
            }
        }

        // Every draft survived: the last target row yields one extra token.
        if !rejected && commits.len() == k && generated.len() + commits.len() < max_new_tokens {
            let p_probs = softmax_with_temperature(&target_batched[k], config.temperature)?;
            let bonus = sample_from_distribution(rng, &p_probs)? as u32;
            commits.push(bonus);
        }

        // Stream the round's tokens to the caller. The callback runs before
        // the stop-token check, so it also sees the stop token itself.
        let mut early_stop_at: Option<usize> = None;
        for (i, &tok) in commits.iter().enumerate() {
            if !on_token(tok) || opts.stop_tokens.contains(&tok) {
                early_stop_at = Some(i + 1);
                break;
            }
        }
        if let Some(stop_at) = early_stop_at {
            commits.truncate(stop_at);
        }

        // Re-sync both decoders with what was actually committed. Only the
        // accepted draft prefix may stay in the histories; a rejected draft
        // token must be dropped and the replacement (or bonus) token observed
        // in its place. Truncation may have cut into the accepted prefix, so
        // clamp to the surviving length.
        // NOTE(review): assumes `batched_logits` leaves the drafted tokens in
        // the target history and `rollback_to` truncates to the given length —
        // confirm against the `Decoder` trait.
        let kept_drafts = accepted_drafts.min(commits.len());
        target.rollback_to(pre_target_len + kept_drafts)?;
        target.observe(&commits[kept_drafts..])?;
        draft.rollback_to(pre_draft_len + kept_drafts)?;
        draft.observe(&commits[kept_drafts..])?;

        generated.extend_from_slice(&commits);
        if early_stop_at.is_some() {
            break 'rounds;
        }
    }
    Ok(generated)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::mock::fixed_distribution;
    use approx::assert_relative_eq;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Config shared by the end-to-end tests: unit temperature and `top_p`
    /// of 1.0, with the given draft window.
    fn unit_config(draft_lookahead: usize) -> VanillaConfig {
        VanillaConfig {
            draft_lookahead,
            temperature: 1.0,
            top_p: 1.0,
        }
    }

    /// Total-variation distance between two equally long probability vectors.
    fn tv_distance(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let l1 = a
            .iter()
            .zip(b)
            .map(|(lhs, rhs)| (lhs - rhs).abs())
            .sum::<f32>();
        0.5 * l1
    }

    /// Runs `trials` single-token generations with fresh mock decoders (but a
    /// single seeded RNG) and returns the observed frequency of each token.
    fn empirical_first_token(
        target_probs: Vec<f32>,
        draft_probs: Vec<f32>,
        trials: usize,
        seed: u64,
        config: &VanillaConfig,
    ) -> Vec<f32> {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut counts = vec![0u64; target_probs.len()];
        for _ in 0..trials {
            let mut target = fixed_distribution(target_probs.clone());
            let mut draft = fixed_distribution(draft_probs.clone());
            let out =
                run_vanilla_sd_with_mock(&mut target, &mut draft, &[0u32], 1, config, &mut rng)
                    .unwrap();
            assert_eq!(out.len(), 1);
            counts[out[0] as usize] += 1;
        }
        let denominator = trials as f32;
        counts
            .into_iter()
            .map(|count| count as f32 / denominator)
            .collect()
    }

    #[test]
    fn always_accept_when_target_geq_draft() {
        // Sweep u over 0.00, 0.01, ..., 1.00.
        for step in 0u8..=100 {
            let u = f32::from(step) / 100.0;
            let outcome = modified_rejection_step(0.4, 0.8, u);
            assert_eq!(outcome, AcceptOutcome::Accepted);
        }
    }

    #[test]
    fn reject_when_u_above_ratio() {
        // ratio = 0.4 / 0.8 = 0.5
        let above = modified_rejection_step(0.8, 0.4, 0.6);
        assert_eq!(above, AcceptOutcome::Rejected);
        let below = modified_rejection_step(0.8, 0.4, 0.4);
        assert_eq!(below, AcceptOutcome::Accepted);
    }

    #[test]
    fn adjusted_distribution_renormalizes() {
        let p = [0.5_f32, 0.3, 0.2];
        let q = [0.1_f32, 0.4, 0.5];
        let adj = adjusted_distribution(&p, &q).unwrap();
        // Only index 0 has p > q, so all residual mass lands there.
        assert_relative_eq!(adj[0], 1.0, max_relative = 1e-6);
        assert_relative_eq!(adj[1], 0.0, max_relative = 1e-6);
        assert_relative_eq!(adj[2], 0.0, max_relative = 1e-6);
        let sum: f32 = adj.iter().sum();
        assert_relative_eq!(sum, 1.0, max_relative = 1e-6);
    }

    #[test]
    fn adjusted_distribution_rejects_zero_mass() {
        let same = [0.5_f32, 0.5];
        assert!(adjusted_distribution(&same, &same).is_err());
    }

    #[test]
    fn adjusted_distribution_rejects_size_mismatch() {
        let p = [0.5_f32, 0.5];
        let q = [1.0_f32];
        assert!(adjusted_distribution(&p, &q).is_err());
    }

    #[test]
    fn sd_matches_target_when_draft_is_uniform() {
        let target = vec![0.5, 0.3, 0.15, 0.05];
        let draft = vec![0.25; 4];
        let cfg = unit_config(4);
        let empirical = empirical_first_token(target.clone(), draft, 20_000, 11, &cfg);
        let tv = tv_distance(&empirical, &target);
        assert!(
            tv < 0.025,
            "TV distance {tv} too large; empirical={empirical:?}, target={target:?}"
        );
    }

    #[test]
    fn sd_matches_target_when_draft_is_skewed_opposite() {
        let target = vec![0.7, 0.15, 0.1, 0.05];
        let draft = vec![0.05, 0.1, 0.15, 0.7];
        let cfg = unit_config(4);
        let empirical = empirical_first_token(target.clone(), draft, 20_000, 23, &cfg);
        let tv = tv_distance(&empirical, &target);
        assert!(
            tv < 0.03,
            "TV distance {tv} too large; empirical={empirical:?}, target={target:?}"
        );
    }

    #[test]
    fn sd_matches_target_when_draft_equals_target() {
        let target = vec![0.4, 0.3, 0.2, 0.1];
        let cfg = unit_config(4);
        let empirical = empirical_first_token(target.clone(), target.clone(), 10_000, 7, &cfg);
        let tv = tv_distance(&empirical, &target);
        assert!(
            tv < 0.025,
            "TV distance {tv}; empirical={empirical:?}, target={target:?}"
        );
    }

    #[test]
    fn sd_emits_only_supported_target_tokens() {
        let cfg = unit_config(4);
        // The target puts no mass on tokens 2 and 3.
        let mut target_dec = fixed_distribution(vec![0.6, 0.4, 0.0, 0.0]);
        let mut draft_dec = fixed_distribution(vec![0.25, 0.25, 0.25, 0.25]);
        let mut rng = StdRng::seed_from_u64(99);
        for _ in 0..2_000 {
            let out = run_vanilla_sd_with_mock(
                &mut target_dec,
                &mut draft_dec,
                &[0u32],
                1,
                &cfg,
                &mut rng,
            )
            .unwrap();
            assert!(
                (out[0] as usize) < 2,
                "produced unsupported token {} from a target with zero mass on it",
                out[0]
            );
        }
    }

    #[test]
    fn sd_produces_max_new_tokens_count() {
        let cfg = unit_config(3);
        let mut t = fixed_distribution(vec![0.4, 0.3, 0.2, 0.1]);
        let mut d = fixed_distribution(vec![0.25, 0.25, 0.25, 0.25]);
        let mut rng = StdRng::seed_from_u64(1);
        for n in [1usize, 5, 16, 32] {
            let out =
                run_vanilla_sd_with_mock(&mut t, &mut d, &[0u32], n, &cfg, &mut rng).unwrap();
            assert_eq!(
                out.len(),
                n,
                "expected exactly {n} tokens, got {}",
                out.len()
            );
        }
    }
}