use crate::speculative_decoding::error::{SpeculativeDecodingError, SpeculativeDecodingResult};
use crate::speculative_decoding::rng::SpecRng;
use crate::speculative_decoding::traits::{LogProb, TokenId};
pub fn accept(draft_logprob: LogProb, target_logprob: LogProb, rng: &mut dyn SpecRng) -> bool {
match (draft_logprob.is_finite(), target_logprob.is_finite()) {
(false, true) => return true,
(true, false) => return false,
(false, false) => return false,
(true, true) => {}
}
let log_ratio = target_logprob - draft_logprob;
if log_ratio >= 0.0 {
return true;
}
let ratio = log_ratio.exp();
let u = rng.next_unit_f64();
u < ratio
}
pub fn adjusted_distribution(
target_logprobs: &[LogProb],
draft_logprobs: &[LogProb],
) -> SpeculativeDecodingResult<Vec<f64>> {
if target_logprobs.len() != draft_logprobs.len() {
return Err(SpeculativeDecodingError::DistributionWidthMismatch {
expected: target_logprobs.len(),
got: draft_logprobs.len(),
});
}
let target: Vec<f64> = target_logprobs
.iter()
.map(|lp| if lp.is_finite() { lp.exp() } else { 0.0 })
.collect();
let draft: Vec<f64> = draft_logprobs
.iter()
.map(|lp| if lp.is_finite() { lp.exp() } else { 0.0 })
.collect();
let mut adjusted: Vec<f64> = target
.iter()
.zip(draft.iter())
.map(|(t, d)| (t - d).max(0.0))
.collect();
let mass: f64 = adjusted.iter().sum();
if mass > 0.0 {
for x in adjusted.iter_mut() {
*x /= mass;
}
return Ok(adjusted);
}
let target_mass: f64 = target.iter().sum();
if target_mass <= 0.0 {
return Err(SpeculativeDecodingError::DegenerateDistribution);
}
for x in target.iter() {
debug_assert!(*x >= 0.0);
}
let fallback: Vec<f64> = target.iter().map(|t| t / target_mass).collect();
Ok(fallback)
}
pub fn sample_index(probs: &[f64], rng: &mut dyn SpecRng) -> SpeculativeDecodingResult<TokenId> {
if probs.is_empty() {
return Err(SpeculativeDecodingError::DegenerateDistribution);
}
let u = rng.next_unit_f64();
let mut cum = 0.0;
for (i, p) in probs.iter().enumerate() {
cum += *p;
if u < cum {
return Ok(i);
}
}
Ok(probs.len() - 1)
}
pub fn resample_from_adjusted_target(
target_logprobs: &[LogProb],
draft_logprobs: &[LogProb],
rng: &mut dyn SpecRng,
) -> SpeculativeDecodingResult<TokenId> {
let adjusted = adjusted_distribution(target_logprobs, draft_logprobs)?;
sample_index(&adjusted, rng)
}
pub fn sample_from_logprobs(
logprobs: &[LogProb],
rng: &mut dyn SpecRng,
) -> SpeculativeDecodingResult<TokenId> {
if logprobs.is_empty() {
return Err(SpeculativeDecodingError::DegenerateDistribution);
}
let max = logprobs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
if !max.is_finite() {
return Err(SpeculativeDecodingError::DegenerateDistribution);
}
let exps: Vec<f64> = logprobs.iter().map(|lp| (lp - max).exp()).collect();
let mass: f64 = exps.iter().sum();
if mass <= 0.0 {
return Err(SpeculativeDecodingError::DegenerateDistribution);
}
let probs: Vec<f64> = exps.iter().map(|e| e / mass).collect();
sample_index(&probs, rng)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::random::{SeedableRng, StdRng};
fn rng(seed: u64) -> StdRng {
StdRng::seed_from_u64(seed)
}
#[test]
fn accept_when_target_dominates() {
let mut r = rng(1);
for _ in 0..100 {
assert!(accept(-2.0, -0.1, &mut r));
}
}
#[test]
fn reject_when_draft_much_more_likely() {
let mut r = rng(2);
let rejects = (0..1000).filter(|_| !accept(-0.1, -3.1, &mut r)).count();
assert!(rejects > 900, "expected >90% rejects, got {}", rejects);
}
#[test]
fn adjusted_nonnegative_and_normalized() {
let tgt: Vec<f64> = vec![0.1f64.ln(), 0.2f64.ln(), 0.7f64.ln()];
let drf: Vec<f64> = vec![0.4f64.ln(), 0.4f64.ln(), 0.2f64.ln()];
let q = adjusted_distribution(&tgt, &drf).expect("adjusted");
assert_eq!(q.len(), 3);
for &p in &q {
assert!(p >= 0.0);
}
let sum: f64 = q.iter().sum();
assert!((sum - 1.0).abs() < 1e-9);
assert!((q[0] - 0.0).abs() < 1e-9);
assert!((q[1] - 0.0).abs() < 1e-9);
assert!((q[2] - 1.0).abs() < 1e-9);
}
#[test]
fn adjusted_falls_back_to_target_when_equal() {
let tgt: Vec<f64> = vec![0.25f64.ln(); 4];
let drf: Vec<f64> = vec![0.25f64.ln(); 4];
let q = adjusted_distribution(&tgt, &drf).expect("adjusted");
assert!((q.iter().sum::<f64>() - 1.0).abs() < 1e-9);
for &p in &q {
assert!((p - 0.25).abs() < 1e-9);
}
}
#[test]
fn adjusted_mismatched_widths_errors() {
let tgt: Vec<f64> = vec![0.5f64.ln(); 3];
let drf: Vec<f64> = vec![0.5f64.ln(); 4];
let res = adjusted_distribution(&tgt, &drf);
assert!(res.is_err());
}
#[test]
fn sample_index_obeys_distribution() {
let mut r = rng(42);
let p = vec![0.1, 0.2, 0.3, 0.4];
let n = 10_000;
let mut counts = [0usize; 4];
for _ in 0..n {
let idx = sample_index(&p, &mut r).expect("sample");
counts[idx] += 1;
}
let emp: Vec<f64> = counts.iter().map(|&c| c as f64 / n as f64).collect();
for (e, t) in emp.iter().zip(p.iter()) {
assert!((e - t).abs() < 0.03, "emp {:?} vs {:?}", emp, p);
}
}
#[test]
fn resample_returns_in_range() {
let mut r = rng(7);
let tgt: Vec<f64> = vec![0.1f64.ln(), 0.2f64.ln(), 0.7f64.ln()];
let drf: Vec<f64> = vec![0.4f64.ln(), 0.4f64.ln(), 0.2f64.ln()];
for _ in 0..200 {
let idx = resample_from_adjusted_target(&tgt, &drf, &mut r).expect("resample");
assert!(idx < 3);
}
}
#[test]
fn sample_from_logprobs_is_normalized() {
let mut r = rng(13);
let lp = vec![-0.1, -1.0, -2.0, -3.0];
let mut counts = [0usize; 4];
for _ in 0..5_000 {
let idx = sample_from_logprobs(&lp, &mut r).expect("sample");
counts[idx] += 1;
}
assert!(counts[0] > counts[1]);
assert!(counts[1] > counts[2]);
assert!(counts[2] > counts[3]);
}
}