use std::marker::PhantomData;
use scirs2_core::random::{SeedableRng, StdRng};
use crate::speculative_decoding::acceptance::{
accept, resample_from_adjusted_target, sample_from_logprobs,
};
use crate::speculative_decoding::error::{SpeculativeDecodingError, SpeculativeDecodingResult};
use crate::speculative_decoding::metrics::SpeculativeMetrics;
use crate::speculative_decoding::rng::SpecRng;
use crate::speculative_decoding::traits::{
DraftModel, DraftProposal, TargetModel, TargetScores, TokenId,
};
/// Tuning knobs for [`SpeculativeDecoder`].
///
/// Start from `Default::default()` and adjust with the `with_*` builder
/// methods; `validate` is run when a decoder is constructed.
#[derive(Debug, Clone, PartialEq)]
pub struct SpeculativeDecoderConfig {
/// Draft depth: upper bound on the number of tokens requested from the
/// draft model in one round. `validate` rejects `0`.
pub k: usize,
/// Cost ratio handed to the metrics collector when a decoder is built.
/// NOTE(review): presumably the cost of a draft step relative to a target
/// step — confirm against `SpeculativeMetrics::with_cost_ratio`.
pub cost_ratio: f32,
/// When `true` and `eos_token` is set, generation stops right after the
/// EOS token has been emitted.
pub stop_on_eos: bool,
/// Token treated as end-of-sequence; with `None` no EOS check fires.
pub eos_token: Option<TokenId>,
}
impl Default for SpeculativeDecoderConfig {
fn default() -> Self {
Self {
k: 4,
cost_ratio: 0.125,
stop_on_eos: false,
eos_token: None,
}
}
}
impl SpeculativeDecoderConfig {
    /// Returns the configuration with the draft depth set to `k`.
    ///
    /// The value is not checked here; [`validate`](Self::validate) rejects `0`.
    #[must_use]
    pub fn with_k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Returns the configuration with the cost ratio set to `r`.
    ///
    /// The value is not checked here; [`validate`](Self::validate) rejects
    /// NaN, infinite and negative ratios.
    #[must_use]
    pub fn with_cost_ratio(mut self, r: f32) -> Self {
        self.cost_ratio = r;
        self
    }

    /// Returns the configuration with `eos` as the end-of-sequence token.
    ///
    /// This also switches `stop_on_eos` on, so generation halts once `eos`
    /// has been emitted.
    #[must_use]
    pub fn with_eos(mut self, eos: TokenId) -> Self {
        self.eos_token = Some(eos);
        self.stop_on_eos = true;
        self
    }

    /// Checks that the configuration can drive a decoder.
    ///
    /// # Errors
    ///
    /// Returns `SpeculativeDecodingError::InvalidConfig` when
    /// * `k` is `0`, or
    /// * `cost_ratio` is NaN, infinite or negative.
    pub fn validate(&self) -> SpeculativeDecodingResult<()> {
        if self.k == 0 {
            return Err(SpeculativeDecodingError::InvalidConfig(
                "draft depth `k` must be at least 1".into(),
            ));
        }
        // The ratio is forwarded verbatim to the metrics collector; a
        // non-finite or negative value would poison every figure derived
        // from it, so refuse it up front.
        if !self.cost_ratio.is_finite() || self.cost_ratio < 0.0 {
            return Err(SpeculativeDecodingError::InvalidConfig(format!(
                "`cost_ratio` must be finite and non-negative, got {}",
                self.cost_ratio
            )));
        }
        Ok(())
    }
}
/// Speculative decoder pairing a draft model with a target model.
///
/// In every round of `generate_with_rng` the draft model proposes up to
/// `config.k` tokens, the target model scores them, and the tokens that
/// come out of the rejection loop are appended to the output.
pub struct SpeculativeDecoder<D: DraftModel, T: TargetModel> {
/// Model that proposes candidate tokens.
draft: D,
/// Model that scores the proposed tokens.
target: T,
/// Configuration that passed `validate` in `new`.
config: SpeculativeDecoderConfig,
/// Statistics recorded once per decoding round.
metrics: SpeculativeMetrics,
// Unit-typed marker: carries no data and is left out of the `Debug` output.
_pd: PhantomData<()>,
}
impl<D: DraftModel + std::fmt::Debug, T: TargetModel + std::fmt::Debug> std::fmt::Debug
for SpeculativeDecoder<D, T>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SpeculativeDecoder")
.field("draft", &self.draft)
.field("target", &self.target)
.field("config", &self.config)
.field("metrics", &self.metrics)
.finish()
}
}
impl<D: DraftModel, T: TargetModel> SpeculativeDecoder<D, T> {
    /// Builds a decoder from a draft model, a target model and a configuration.
    ///
    /// # Errors
    ///
    /// * Any error reported by `config.validate()`.
    /// * `SpeculativeDecodingError::VocabMismatch` when the two models report
    ///   different vocabulary sizes.
    pub fn new(
        draft: D,
        target: T,
        config: SpeculativeDecoderConfig,
    ) -> SpeculativeDecodingResult<Self> {
        config.validate()?;
        let (draft_vocab, target_vocab) = (draft.vocab_size(), target.vocab_size());
        if draft_vocab != target_vocab {
            return Err(SpeculativeDecodingError::VocabMismatch {
                draft: draft_vocab,
                target: target_vocab,
            });
        }
        let metrics = SpeculativeMetrics::new().with_cost_ratio(config.cost_ratio);
        Ok(Self {
            draft,
            target,
            config,
            metrics,
            _pd: PhantomData,
        })
    }

    /// Statistics recorded by the decoding rounds run so far.
    pub fn metrics(&self) -> &SpeculativeMetrics {
        &self.metrics
    }

    /// Resets the recorded statistics.
    pub fn reset_metrics(&mut self) {
        self.metrics.reset();
    }

    /// The configuration this decoder was built with.
    pub fn config(&self) -> &SpeculativeDecoderConfig {
        &self.config
    }

    /// Generates up to `max_tokens` tokens that continue `prefix`.
    ///
    /// Sampling uses a `StdRng` seeded with the fixed value `42`, so every
    /// call starts from the same random stream. Use
    /// [`generate_with_rng`](Self::generate_with_rng) to supply your own.
    ///
    /// # Errors
    ///
    /// Same as [`generate_with_rng`](Self::generate_with_rng).
    pub fn generate(
        &mut self,
        prefix: &[TokenId],
        max_tokens: usize,
    ) -> SpeculativeDecodingResult<Vec<TokenId>> {
        let mut rng = StdRng::seed_from_u64(42);
        self.generate_with_rng(prefix, max_tokens, &mut rng)
    }

    /// Generates up to `max_tokens` tokens that continue `prefix`, drawing
    /// randomness from `rng`.
    ///
    /// The returned vector holds only the newly generated tokens (the prefix
    /// is not included). Generation ends when `max_tokens` tokens have been
    /// produced or, if `stop_on_eos` is set and an `eos_token` is configured,
    /// right after that token has been emitted. One metrics record is written
    /// per round.
    ///
    /// `max_tokens` may be arbitrarily large (e.g. `usize::MAX` to rely on
    /// EOS alone); it is treated purely as an upper bound.
    ///
    /// # Errors
    ///
    /// * `SpeculativeDecodingError::EmptyPrefix` when `prefix` is empty.
    /// * Any error returned by the draft model, the target model, the shape
    ///   validators or the rejection loop.
    pub fn generate_with_rng(
        &mut self,
        prefix: &[TokenId],
        max_tokens: usize,
        rng: &mut dyn SpecRng,
    ) -> SpeculativeDecodingResult<Vec<TokenId>> {
        // `max_tokens` is only a cap: reserving it verbatim would panic with
        // a capacity overflow (or exhaust memory) when callers pass a huge
        // value and expect EOS to end generation. Bound the eager reservation
        // and let the vector grow on demand beyond it.
        const MAX_PREALLOCATED_TOKENS: usize = 4096;

        // Metrics counters are `u32`; saturate instead of silently wrapping.
        fn saturating_u32(n: usize) -> u32 {
            if n > u32::MAX as usize {
                u32::MAX
            } else {
                n as u32
            }
        }

        if prefix.is_empty() {
            return Err(SpeculativeDecodingError::EmptyPrefix);
        }

        let vocab = self.draft.vocab_size();
        let k = self.config.k;
        // `Some(eos)` only when EOS stopping is both enabled and configured.
        let stop_token = if self.config.stop_on_eos {
            self.config.eos_token
        } else {
            None
        };

        let mut working = prefix.to_vec();
        let mut output: Vec<TokenId> =
            Vec::with_capacity(max_tokens.min(MAX_PREALLOCATED_TOKENS));

        while output.len() < max_tokens {
            // Never draft more tokens than are still wanted. The loop
            // condition guarantees the remainder is at least 1.
            let round_k = k.min(max_tokens - output.len());

            let proposal = self.draft.propose(&working, round_k, rng)?;
            validate_proposal(&proposal, round_k, vocab)?;
            let target_scores = self.target.verify(&working, &proposal.tokens)?;
            validate_target_scores(&target_scores, round_k, vocab)?;
            let (accepted_count, emitted) =
                run_rejection_loop(&proposal, &target_scores, round_k, vocab, rng)?;

            // Commit tokens until the budget is exhausted or EOS is emitted;
            // anything left in `emitted` after that point is discarded.
            let mut committed = 0usize;
            let mut hit_eos = false;
            for token in emitted {
                output.push(token);
                working.push(token);
                committed += 1;
                if stop_token == Some(token) {
                    hit_eos = true;
                    break;
                }
                if output.len() >= max_tokens {
                    break;
                }
            }

            self.metrics.record_round(
                saturating_u32(round_k),
                saturating_u32(accepted_count),
                saturating_u32(committed),
                saturating_u32(round_k),
            );

            if hit_eos {
                break;
            }
        }
        Ok(output)
    }
}
/// Checks that a draft proposal has the shape the decoder expects.
///
/// The checks run in this order and the first failure is returned:
/// 1. `tokens`, `token_logprobs` and `distributions` each hold exactly `k`
///    entries (`DraftShapeMismatch`);
/// 2. every distribution row is `vocab` wide (`DistributionWidthMismatch`,
///    reporting the first offending row);
/// 3. every proposed token is below `vocab` (`TokenOutOfRange`, reporting
///    the first offending token).
fn validate_proposal(p: &DraftProposal, k: usize, vocab: usize) -> SpeculativeDecodingResult<()> {
    let n_tokens = p.tokens.len();
    let n_logprobs = p.token_logprobs.len();
    let n_rows = p.distributions.len();
    let shapes_agree = n_tokens == k && n_logprobs == k && n_rows == k;
    if !shapes_agree {
        return Err(SpeculativeDecodingError::DraftShapeMismatch {
            tokens: n_tokens,
            logprobs: n_logprobs,
            distributions: n_rows,
        });
    }

    if let Some(bad_row) = p.distributions.iter().find(|row| row.len() != vocab) {
        return Err(SpeculativeDecodingError::DistributionWidthMismatch {
            expected: vocab,
            got: bad_row.len(),
        });
    }

    match p.tokens.iter().find(|&&t| t >= vocab) {
        Some(&token) => Err(SpeculativeDecodingError::TokenOutOfRange {
            token,
            vocab_size: vocab,
        }),
        None => Ok(()),
    }
}
/// Checks that the target model returned one distribution per drafted
/// position plus one extra row.
///
/// Fails with `TargetShapeMismatch` when the row count is not `k + 1`, and
/// otherwise with `DistributionWidthMismatch` for the first row whose width
/// differs from `vocab`.
fn validate_target_scores(
    t: &TargetScores,
    k: usize,
    vocab: usize,
) -> SpeculativeDecodingResult<()> {
    let expected_rows = k + 1;
    let got_rows = t.distributions.len();
    if got_rows != expected_rows {
        return Err(SpeculativeDecodingError::TargetShapeMismatch {
            expected: expected_rows,
            got: got_rows,
        });
    }

    t.distributions
        .iter()
        .map(|row| row.len())
        .find(|&width| width != vocab)
        .map_or(Ok(()), |got| {
            Err(SpeculativeDecodingError::DistributionWidthMismatch {
                expected: vocab,
                got,
            })
        })
}
/// Runs one round of accept/reject over the `k` drafted tokens.
///
/// Positions are visited in order. While `accept` approves a drafted token
/// it is emitted as-is. At the first rejection a replacement is drawn with
/// `resample_from_adjusted_target` from that position's target and draft
/// rows, emitted, and the round ends. If all `k` tokens are approved, one
/// extra token is drawn with `sample_from_logprobs` from target row `k`.
///
/// Returns `(accepted, emitted)`, where `accepted` counts the approved draft
/// tokens and `emitted` always ends with exactly one sampled token.
///
/// # Errors
///
/// Propagates sampler errors, and returns `TokenOutOfRange` when a sampled
/// token is not below `vocab`.
///
/// Indexes `proposal` and `target_scores` directly, so callers are expected
/// to have validated their shapes beforehand.
fn run_rejection_loop(
    proposal: &DraftProposal,
    target_scores: &TargetScores,
    k: usize,
    vocab: usize,
    rng: &mut dyn SpecRng,
) -> SpeculativeDecodingResult<(usize, Vec<TokenId>)> {
    // Passes a sampled token through only if it lies inside the vocabulary.
    fn in_vocab(token: TokenId, vocab: usize) -> SpeculativeDecodingResult<TokenId> {
        if token >= vocab {
            return Err(SpeculativeDecodingError::TokenOutOfRange {
                token,
                vocab_size: vocab,
            });
        }
        Ok(token)
    }

    let mut emitted: Vec<TokenId> = Vec::with_capacity(k + 1);
    // Doubles as the count of draft tokens approved so far.
    let mut position: usize = 0;

    while position < k {
        let candidate = proposal.tokens[position];
        let target_row = &target_scores.distributions[position];
        let draft_row = &proposal.distributions[position];

        if !accept(draft_row[candidate], target_row[candidate], rng) {
            // First rejection: substitute a resampled token and stop.
            let sampled = resample_from_adjusted_target(target_row, draft_row, rng)?;
            emitted.push(in_vocab(sampled, vocab)?);
            return Ok((position, emitted));
        }

        emitted.push(candidate);
        position += 1;
    }

    // Every draft token survived: draw the extra token from the final row.
    let sampled = sample_from_logprobs(&target_scores.distributions[k], rng)?;
    emitted.push(in_vocab(sampled, vocab)?);
    Ok((position, emitted))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses a draft depth of 4 and validates.
    #[test]
    fn config_default_is_sensible() {
        let defaults = SpeculativeDecoderConfig::default();
        let depth = defaults.k;
        assert_eq!(depth, 4);
        let verdict = defaults.validate();
        assert!(verdict.is_ok());
    }

    /// A draft depth of zero must not pass validation.
    #[test]
    fn config_k_zero_rejected() {
        let verdict = SpeculativeDecoderConfig::default().with_k(0).validate();
        assert!(verdict.is_err());
    }

    /// Chained builders each leave their mark on the final configuration.
    #[test]
    fn config_builders_compose() {
        let built = SpeculativeDecoderConfig::default()
            .with_k(2)
            .with_cost_ratio(0.05)
            .with_eos(7);
        let SpeculativeDecoderConfig {
            k,
            cost_ratio,
            stop_on_eos,
            eos_token,
        } = built;
        assert_eq!(k, 2);
        assert!((cost_ratio - 0.05).abs() < 1e-6);
        assert_eq!(eos_token, Some(7));
        assert!(stop_on_eos);
    }
}