use rlx_ir::Philox4x32;
/// Output of the draft model for one speculative round.
///
/// `tokens[i]` is the i-th proposed token and `probs[i]` is the draft model's
/// probability row for that position, indexed by token id (so
/// `probs[i][tokens[i] as usize]` is the draft probability of the proposed
/// token). `tokens` and `probs` must have the same length;
/// `speculative_accept` asserts this.
#[derive(Debug, Clone)]
pub struct DraftProposal {
// Proposed token ids, one per speculative position.
pub tokens: Vec<u32>,
// Per-position draft probability rows, indexed by token id.
pub probs: Vec<Vec<f32>>,
}
/// Output of the target model when scoring a draft proposal.
///
/// `probs[i]` is the target model's probability row for position `i`, indexed
/// by token id. It must contain exactly one row per proposed token (same length
/// as `DraftProposal::probs`); `speculative_accept` asserts this.
#[derive(Debug, Clone)]
pub struct VerifyResult {
// Per-position target probability rows, indexed by token id.
pub probs: Vec<Vec<f32>>,
}
/// Result of running the acceptance test over one draft proposal.
///
/// `accepted` holds the prefix of draft tokens that passed verification.
/// `corrected` is `Some(token)` when a draft token was rejected and a
/// replacement was sampled in its place, and `None` when every draft token
/// was accepted.
#[derive(Debug, Clone)]
pub struct AcceptDecision {
    pub accepted: Vec<u32>,
    pub corrected: Option<u32>,
}

impl AcceptDecision {
    /// Number of tokens this decision emits: the accepted prefix plus the
    /// corrected token, if any.
    pub fn total_tokens(&self) -> usize {
        let correction = usize::from(self.corrected.is_some());
        self.accepted.len() + correction
    }
}
/// A model that can take part in speculative decoding, as either the draft
/// or the target side of a [`SpecDecoder`].
pub trait Speculator {
    /// Propose `n` tokens continuing `context`, along with the per-position
    /// probability rows they were drawn from. `SpecDecoder::step` calls this
    /// on the draft model.
    fn propose(&mut self, context: &[u32], n: usize) -> DraftProposal;

    /// Score `proposed` as a continuation of `context`, returning one
    /// probability row per proposed token. `SpecDecoder::step` calls this on
    /// the target model.
    fn verify(&mut self, context: &[u32], proposed: &[u32]) -> VerifyResult;

    /// Notification of the tokens actually emitted after `context` in a step.
    /// The default implementation ignores it; models that keep state (e.g. a
    /// cache) can override it.
    fn commit(&mut self, _context: &[u32], _accepted: &[u32]) {}
}
/// Run the speculative-sampling acceptance test over a draft proposal.
///
/// Positions are walked in order. Draft token `t` at position `i` is accepted
/// when a uniform draw `r` from `rng` satisfies
/// `r < min(1, q_i[t] / p_i[t])`. Here `p_i` is the draft row and `q_i` the
/// target row. The draft probability is clamped to `f32::MIN_POSITIVE` so a
/// zero entry cannot divide by zero.
///
/// At the first rejection, a replacement token is drawn from the normalised
/// residual `max(q_i - p_i, 0)` and returned as `corrected`; later positions
/// are not examined. If every position is accepted, `corrected` is `None`.
/// No extra token is sampled, since `verify` carries only one row per draft
/// token.
///
/// # Panics
/// - if `proposal.tokens` and `proposal.probs` differ in length;
/// - if `proposal.probs` and `verify.probs` differ in length;
/// - if a proposed token id is out of range for its probability row.
pub fn speculative_accept(
    proposal: &DraftProposal,
    verify: &VerifyResult,
    rng: &mut Philox4x32,
) -> AcceptDecision {
    assert_eq!(
        proposal.tokens.len(),
        proposal.probs.len(),
        "DraftProposal: tokens and probs must agree"
    );
    assert_eq!(
        proposal.probs.len(),
        verify.probs.len(),
        "DraftProposal and VerifyResult must propose the same n"
    );

    let mut accepted: Vec<u32> = Vec::with_capacity(proposal.tokens.len());
    let positions = proposal
        .tokens
        .iter()
        .zip(&proposal.probs)
        .zip(&verify.probs);

    for ((&token, draft_row), target_row) in positions {
        let idx = token as usize;
        let draft_p = draft_row[idx].max(f32::MIN_POSITIVE);
        let target_p = target_row[idx];
        let keep_prob = (target_p / draft_p).min(1.0);
        let draw = rng.next_f32();
        if draw < keep_prob {
            accepted.push(token);
            continue;
        }
        // First rejection: resample this position from the residual and stop.
        let replacement = sample_corrected_residual(draft_row, target_row, rng);
        return AcceptDecision {
            accepted,
            corrected: Some(replacement),
        };
    }

    AcceptDecision {
        accepted,
        corrected: None,
    }
}
/// Sample a replacement token after a draft rejection.
///
/// Builds the residual `max(q - p, 0)` element-wise (`p` = draft row,
/// `q` = target row). It rescales the residual to sum to 1 and samples from
/// it. If the residual mass is at or below `f32::MIN_POSITIVE` (e.g. `q` is
/// nowhere above `p`), it samples directly from `q` instead.
///
/// Rows of different lengths are silently truncated to the shorter one by
/// `zip` — NOTE(review): callers presumably pass equal-length rows; confirm.
fn sample_corrected_residual(p: &[f32], q: &[f32], rng: &mut Philox4x32) -> u32 {
    let mut residual: Vec<f32> = q
        .iter()
        .zip(p)
        .map(|(&target, &draft)| (target - draft).max(0.0))
        .collect();
    let mass: f32 = residual.iter().sum();
    if mass <= f32::MIN_POSITIVE {
        sample_from(q, rng)
    } else {
        let scale = 1.0 / mass;
        residual.iter_mut().for_each(|w| *w *= scale);
        sample_from(&residual, rng)
    }
}
/// Draw one index from the categorical distribution `probs` by inverse-CDF
/// sampling with a single uniform draw from `rng`.
///
/// Index `i` is chosen when `cum[i-1] <= r < cum[i]`, so zero-probability
/// entries own an empty interval and can never be selected. The strict `<`
/// matters when `r == 0.0` and `probs[0] == 0.0`.
///
/// `probs` is expected to sum to ~1; it is not renormalised here. Assumes
/// `Philox4x32::next_f32` yields values in `[0, 1)` — TODO confirm.
///
/// # Panics
/// If `probs` is empty.
fn sample_from(probs: &[f32], rng: &mut Philox4x32) -> u32 {
    assert!(!probs.is_empty(), "sample_from: empty distribution");
    let r = rng.next_f32();
    let mut acc = 0f32;
    for (i, &p) in probs.iter().enumerate() {
        acc += p;
        if r < acc {
            return i as u32;
        }
    }
    // f32 rounding can leave the cumulative sum slightly below r, especially
    // after normalising a large residual. That top sliver belongs to the last
    // token with positive mass, not to a zero-probability tail entry.
    probs
        .iter()
        .rposition(|&p| p > 0.0)
        .unwrap_or(probs.len() - 1) as u32
}
/// Speculative decoder pairing a draft model `D` with a target model `T`.
///
/// Each `step` asks the draft for `n` tokens, has the target verify them, and
/// emits the accepted prefix plus an optional corrected token. Acceptance draws
/// come from a private `Philox4x32` seeded in `new`.
pub struct SpecDecoder<D: Speculator, T: Speculator> {
// Draft model: proposes tokens.
pub draft: D,
// Target model: verifies proposed tokens.
pub target: T,
// Number of draft tokens requested per step.
pub n: usize,
// RNG used for acceptance tests and residual sampling.
rng: Philox4x32,
}
impl<D: Speculator, T: Speculator> SpecDecoder<D, T> {
pub fn new(draft: D, target: T, n: usize, seed: u64) -> Self {
Self {
draft,
target,
n,
rng: Philox4x32::new(seed),
}
}
pub fn step(&mut self, context: &[u32]) -> Vec<u32> {
let proposal = self.draft.propose(context, self.n);
let verify = self.target.verify(context, &proposal.tokens);
let decision = speculative_accept(&proposal, &verify, &mut self.rng);
let mut out = decision.accepted;
if let Some(c) = decision.corrected {
out.push(c);
}
self.draft.commit(context, &out);
self.target.commit(context, &out);
out
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// When draft and target rows are identical, q/p == 1 for every proposed
    /// token, so every draft token must be accepted.
    #[test]
    fn identical_distributions_accept_all() {
        let n = 4;
        let vocab = 8;
        let mut probs = Vec::with_capacity(n);
        let mut tokens = Vec::with_capacity(n);
        for i in 0..n {
            let mut row = vec![0.01f32; vocab];
            let pick = (i * 2) % vocab;
            row[pick] = 1.0 - 0.01 * (vocab - 1) as f32;
            probs.push(row);
            tokens.push(pick as u32);
        }
        let proposal = DraftProposal {
            tokens: tokens.clone(),
            probs: probs.clone(),
        };
        let verify = VerifyResult { probs };
        for seed in 0..100u64 {
            let mut rng = Philox4x32::new(seed + 1);
            let d = speculative_accept(&proposal, &verify, &mut rng);
            assert_eq!(d.accepted, tokens, "seed {seed}: should accept all");
            assert!(d.corrected.is_none());
        }
    }

    /// Draft strongly prefers token 0 while the target strongly disfavours it,
    /// so the acceptance ratio is ~0.01 and acceptances should be rare.
    #[test]
    fn divergent_distributions_reject_sometimes() {
        let n = 4;
        let draft_row = vec![0.97f32, 0.01, 0.01, 0.01];
        let target_row = vec![0.01f32, 0.01, 0.01, 0.97];
        let proposal = DraftProposal {
            tokens: vec![0u32; n],
            probs: vec![draft_row; n],
        };
        let verify = VerifyResult {
            probs: vec![target_row; n],
        };
        let trials: u64 = 200;
        let possible = trials as usize * n;
        let mut total_accepted = 0usize;
        for seed in 0..trials {
            let mut rng = Philox4x32::new(seed + 1);
            let d = speculative_accept(&proposal, &verify, &mut rng);
            total_accepted += d.accepted.len();
            if d.accepted.len() < n {
                assert!(
                    d.corrected.is_some(),
                    "rejection at seed {seed} should yield a corrected token"
                );
            }
        }
        assert!(
            total_accepted < possible / 10,
            "divergent distributions should accept rarely; got {total_accepted}/{possible}"
        );
    }

    /// The corrected token must come from the residual max(q - p, 0), i.e.
    /// only from tokens the target prefers more than the draft does.
    #[test]
    fn rejection_samples_only_from_residual_support() {
        let draft_row = [0.5f32, 0.5, 0.0, 0.0];
        let target_row = [0.25f32, 0.25, 0.25, 0.25];
        for seed in 0..200u64 {
            let mut rng = Philox4x32::new(seed + 1);
            let t = sample_corrected_residual(&draft_row, &target_row, &mut rng);
            assert!(
                t == 2 || t == 3,
                "seed {seed}: residual sample {t} outside residual support"
            );
        }
    }

    #[test]
    fn total_tokens_counts_correction() {
        let d = AcceptDecision {
            accepted: vec![1, 2],
            corrected: Some(9),
        };
        assert_eq!(d.total_tokens(), 3);
        let d = AcceptDecision {
            accepted: vec![1, 2],
            corrected: None,
        };
        assert_eq!(d.total_tokens(), 2);
    }

    /// Test speculator that always proposes/verifies the same peaked row.
    struct CannedSpeculator {
        next_token: u32,
        peaked_prob: f32,
    }

    impl CannedSpeculator {
        /// Vocab-8 row with `peaked_prob` on `next_token` and the remainder
        /// spread evenly over the other tokens.
        fn row(&self) -> Vec<f32> {
            let vocab = 8;
            let mut row = vec![(1.0 - self.peaked_prob) / (vocab - 1) as f32; vocab];
            row[self.next_token as usize] = self.peaked_prob;
            row
        }
    }

    impl Speculator for CannedSpeculator {
        fn propose(&mut self, _ctx: &[u32], n: usize) -> DraftProposal {
            DraftProposal {
                tokens: vec![self.next_token; n],
                probs: (0..n).map(|_| self.row()).collect(),
            }
        }

        fn verify(&mut self, _ctx: &[u32], proposed: &[u32]) -> VerifyResult {
            VerifyResult {
                probs: proposed.iter().map(|_| self.row()).collect(),
            }
        }
    }

    /// Identical draft and target mean no rejection, so a step emits exactly
    /// the n drafted tokens. No bonus token is sampled, because `verify`
    /// returns only n rows.
    #[test]
    fn spec_decoder_step_emits_n_tokens_when_aligned() {
        let draft = CannedSpeculator {
            next_token: 5,
            peaked_prob: 0.95,
        };
        let target = CannedSpeculator {
            next_token: 5,
            peaked_prob: 0.95,
        };
        let mut dec = SpecDecoder::new(draft, target, 4, 1);
        let context = vec![0u32, 1, 2];
        let out = dec.step(&context);
        assert_eq!(
            out.len(),
            4,
            "aligned step should emit n tokens (no rejection)"
        );
        assert!(out.iter().all(|&t| t == 5));
    }
}