use crate::speculative_decoding::error::{SpeculativeDecodingError, SpeculativeDecodingResult};
use crate::speculative_decoding::rng::SpecRng;
use crate::speculative_decoding::traits::{
DraftModel, DraftProposal, LogProb, TargetModel, TargetScores, TokenId,
};
/// Draft model backed by a single fixed categorical distribution.
///
/// Both vectors are built by `FixedDistDraftModel::new` and hold one entry per token index.
#[derive(Debug, Clone)]
pub struct FixedDistDraftModel {
/// Probabilities obtained by dividing the constructor input by its sum.
probs: Vec<f64>,
/// Natural log of each entry of `probs`; negative infinity where the probability is not positive.
logprobs: Vec<LogProb>,
}
impl FixedDistDraftModel {
    /// Builds a draft model from non-negative weights, rescaling them so they sum to one.
    ///
    /// The log-probabilities are precomputed from the rescaled weights; an entry
    /// with zero probability gets a log-probability of negative infinity.
    ///
    /// # Errors
    ///
    /// Returns [`SpeculativeDecodingError::InvalidConfig`] when `probs` is empty,
    /// when any entry is negative, NaN or infinite, or when the total mass is not
    /// a positive finite number.
    pub fn new(probs: Vec<f64>) -> SpeculativeDecodingResult<Self> {
        if probs.is_empty() {
            return Err(SpeculativeDecodingError::InvalidConfig(
                "FixedDistDraftModel requires a non-empty probability vector".into(),
            ));
        }
        // A negative weight can hide behind a positive total (e.g. `[-1.0, 3.0]`) and
        // would normalize into a negative "probability", so every entry is checked
        // individually instead of relying on the sum alone.
        if probs.iter().any(|p| !(p.is_finite() && *p >= 0.0)) {
            return Err(SpeculativeDecodingError::InvalidConfig(
                "FixedDistDraftModel probabilities must be finite and non-negative".into(),
            ));
        }
        let sum: f64 = probs.iter().copied().sum();
        // Finite entries can still overflow to infinity when added, and all-zero
        // input has no mass to normalize.
        if !(sum > 0.0 && sum.is_finite()) {
            return Err(SpeculativeDecodingError::InvalidConfig(
                "FixedDistDraftModel probabilities must have positive finite mass".into(),
            ));
        }
        let normalized: Vec<f64> = probs.iter().map(|p| *p / sum).collect();
        let logprobs: Vec<LogProb> = normalized
            .iter()
            .map(|p| if *p > 0.0 { p.ln() } else { f64::NEG_INFINITY })
            .collect();
        Ok(Self {
            probs: normalized,
            logprobs,
        })
    }

    /// Normalized probabilities, one per token index.
    pub fn probs(&self) -> &[f64] {
        &self.probs
    }

    /// Natural-log probabilities matching [`Self::probs`] entry for entry.
    pub fn logprobs(&self) -> &[LogProb] {
        &self.logprobs
    }
}
impl DraftModel for FixedDistDraftModel {
    /// Number of entries in the fixed distribution.
    fn vocab_size(&self) -> usize {
        self.probs.len()
    }

    /// Samples `k` tokens independently from the fixed distribution.
    ///
    /// The prefix is ignored. Each proposal row carries the sampled token, its
    /// log-probability, and a full copy of the log-probability vector. The first
    /// sampling error aborts the proposal and is returned unchanged.
    fn propose(
        &self,
        _prefix: &[TokenId],
        k: usize,
        rng: &mut dyn SpecRng,
    ) -> SpeculativeDecodingResult<DraftProposal> {
        // Draws happen strictly in order and stop at the first failure, so the
        // random stream is consumed exactly as a plain loop would consume it.
        let tokens = (0..k)
            .map(|_| sample_categorical(&self.probs, &mut *rng))
            .collect::<SpeculativeDecodingResult<Vec<TokenId>>>()?;
        let token_logprobs = tokens.iter().map(|&token| self.logprobs[token]).collect();
        // Every position shares the same distribution, so the rows are plain copies.
        let distributions = vec![self.logprobs.clone(); k];
        Ok(DraftProposal {
            tokens,
            token_logprobs,
            distributions,
        })
    }
}
/// Target model backed by a single fixed categorical distribution.
///
/// Both vectors are built by `FixedDistTargetModel::new` and hold one entry per token index.
#[derive(Debug, Clone)]
pub struct FixedDistTargetModel {
/// Probabilities obtained by dividing the constructor input by its sum.
probs: Vec<f64>,
/// Natural log of each entry of `probs`; negative infinity where the probability is not positive.
logprobs: Vec<LogProb>,
}
impl FixedDistTargetModel {
    /// Builds a target model from non-negative weights, rescaling them so they sum to one.
    ///
    /// The log-probabilities are precomputed from the rescaled weights; an entry
    /// with zero probability gets a log-probability of negative infinity.
    ///
    /// # Errors
    ///
    /// Returns [`SpeculativeDecodingError::InvalidConfig`] when `probs` is empty,
    /// when any entry is negative, NaN or infinite, or when the total mass is not
    /// a positive finite number.
    pub fn new(probs: Vec<f64>) -> SpeculativeDecodingResult<Self> {
        if probs.is_empty() {
            return Err(SpeculativeDecodingError::InvalidConfig(
                "FixedDistTargetModel requires a non-empty probability vector".into(),
            ));
        }
        // A negative weight can hide behind a positive total (e.g. `[-1.0, 3.0]`) and
        // would normalize into a negative "probability", so every entry is checked
        // individually instead of relying on the sum alone.
        if probs.iter().any(|p| !(p.is_finite() && *p >= 0.0)) {
            return Err(SpeculativeDecodingError::InvalidConfig(
                "FixedDistTargetModel probabilities must be finite and non-negative".into(),
            ));
        }
        let sum: f64 = probs.iter().copied().sum();
        // Finite entries can still overflow to infinity when added, and all-zero
        // input has no mass to normalize.
        if !(sum > 0.0 && sum.is_finite()) {
            return Err(SpeculativeDecodingError::InvalidConfig(
                "FixedDistTargetModel probabilities must have positive finite mass".into(),
            ));
        }
        let normalized: Vec<f64> = probs.iter().map(|p| *p / sum).collect();
        let logprobs: Vec<LogProb> = normalized
            .iter()
            .map(|p| if *p > 0.0 { p.ln() } else { f64::NEG_INFINITY })
            .collect();
        Ok(Self {
            probs: normalized,
            logprobs,
        })
    }

    /// Normalized probabilities, one per token index.
    pub fn probs(&self) -> &[f64] {
        &self.probs
    }

    /// Natural-log probabilities matching [`Self::probs`] entry for entry.
    pub fn logprobs(&self) -> &[LogProb] {
        &self.logprobs
    }
}
impl TargetModel for FixedDistTargetModel {
    /// Number of entries in the fixed distribution.
    fn vocab_size(&self) -> usize {
        self.probs.len()
    }

    /// Scores a draft by repeating the fixed log-probability vector.
    ///
    /// The prefix and the draft token values are ignored; only the draft length
    /// matters. One row is produced per draft token plus one extra row.
    fn verify(
        &self,
        _prefix: &[TokenId],
        draft_tokens: &[TokenId],
    ) -> SpeculativeDecodingResult<TargetScores> {
        let row_count = draft_tokens.len() + 1;
        Ok(TargetScores {
            distributions: vec![self.logprobs.clone(); row_count],
        })
    }
}
/// Draws one index from the categorical distribution `probs` by inverse-CDF sampling.
///
/// A single uniform value `u` is taken from `rng`; the result is the first index
/// whose running total of positive mass exceeds `u`. Entries that are not strictly
/// positive (zero, negative, NaN) carry no mass and are never returned.
///
/// When the running total never exceeds `u` (floating-point rounding, or weights
/// that sum to less than one) the last index with positive mass is returned, so
/// the result is always a token the distribution can actually produce.
///
/// # Errors
///
/// Returns [`SpeculativeDecodingError::DegenerateDistribution`] when `probs` is
/// empty or contains no strictly positive entry.
pub(crate) fn sample_categorical(
    probs: &[f64],
    rng: &mut dyn SpecRng,
) -> SpeculativeDecodingResult<TokenId> {
    if probs.is_empty() {
        return Err(SpeculativeDecodingError::DegenerateDistribution);
    }
    // Exactly one draw is consumed per call, including on the error path below,
    // so the random stream stays aligned for callers regardless of the outcome.
    let u = rng.next_unit_f64();
    let mut cum = 0.0;
    let mut last_positive = None;
    for (i, &p) in probs.iter().enumerate() {
        if p > 0.0 {
            cum += p;
            if u < cum {
                return Ok(i);
            }
            last_positive = Some(i);
        }
    }
    // Falling back to `probs.len() - 1` could hand out a zero-probability token
    // whenever the distribution ends in zeros; use the last index with mass instead.
    last_positive.ok_or(SpeculativeDecodingError::DegenerateDistribution)
}
/// Alias exposing [`FixedDistDraftModel`] under the name `MockDraftModel`.
pub type MockDraftModel = FixedDistDraftModel;
/// Alias exposing [`FixedDistTargetModel`] under the name `MockTargetModel`.
pub type MockTargetModel = FixedDistTargetModel;
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::{SeedableRng, StdRng};

    /// Equal unnormalized weights must come out as a uniform distribution.
    #[test]
    fn draft_model_normalizes_input() {
        let model = FixedDistDraftModel::new(vec![2.0, 2.0]).expect("normalize");
        assert!(model.probs().iter().all(|p| (p - 0.5).abs() < 1e-9));
    }

    /// An empty weight vector is not a valid distribution.
    #[test]
    fn draft_model_rejects_empty() {
        assert!(FixedDistDraftModel::new(vec![]).is_err());
    }

    /// Weights with no mass cannot be normalized.
    #[test]
    fn draft_model_rejects_zero_mass() {
        assert!(FixedDistDraftModel::new(vec![0.0, 0.0, 0.0]).is_err());
    }

    /// A proposal of `k` tokens has `k` entries in every parallel vector.
    #[test]
    fn propose_shapes_are_consistent() {
        let draft = FixedDistDraftModel::new(vec![0.25; 4]).expect("d");
        let mut rng = StdRng::seed_from_u64(1);
        let proposal = draft.propose(&[0, 1, 2], 3, &mut rng).expect("propose");
        assert_eq!(proposal.tokens.len(), 3);
        assert_eq!(proposal.token_logprobs.len(), 3);
        assert_eq!(proposal.distributions.len(), 3);
        assert!(proposal.distributions.iter().all(|row| row.len() == 4));
    }

    /// Verification yields one row per draft token plus one extra row.
    #[test]
    fn verify_returns_k_plus_one_rows() {
        let target = FixedDistTargetModel::new(vec![0.25; 4]).expect("t");
        let scores = target.verify(&[0, 1], &[1, 2, 3]).expect("verify");
        assert_eq!(scores.distributions.len(), 4);
        assert!(scores.distributions.iter().all(|row| row.len() == 4));
    }

    /// Two generators seeded identically must produce the same proposal.
    #[test]
    fn propose_is_reproducible_with_seed() {
        let draft = FixedDistDraftModel::new(vec![0.1, 0.2, 0.3, 0.4]).expect("d");
        let mut first_rng = StdRng::seed_from_u64(7);
        let mut second_rng = StdRng::seed_from_u64(7);
        let first = draft.propose(&[0], 8, &mut first_rng).expect("p1");
        let second = draft.propose(&[0], 8, &mut second_rng).expect("p2");
        assert_eq!(first.tokens, second.tokens);
    }
}