use crate::attn::speculative_decode::{
CacheCheckpoint, DraftCacheState, SpeculativeDecodeConfig, TargetCacheState,
};
use crate::error::DnnError;
/// Decides whether a single drafted token passes speculative-decoding
/// verification.
///
/// The token is accepted when `rng_sample < min(1, target_prob / draft_prob)`.
///
/// * `draft_prob` – probability the draft model assigned to the token.
/// * `target_prob` – probability the target model assigned to the same token.
/// * `rng_sample` – the random draw compared against the acceptance ratio
///   (presumably uniform in `[0, 1)`; the function itself does not check).
///
/// Returns `false` when `draft_prob` is zero or negative, and also when the
/// ratio is NaN (either probability is NaN, or both are infinite).
#[must_use]
pub fn accept_token(draft_prob: f32, target_prob: f32, rng_sample: f32) -> bool {
    if draft_prob <= 0.0 {
        return false;
    }
    let ratio = target_prob / draft_prob;
    // `f32::min` returns the non-NaN operand, so `NaN.min(1.0)` would be 1.0
    // and a corrupted probability would be accepted almost unconditionally.
    if ratio.is_nan() {
        return false;
    }
    rng_sample < ratio.min(1.0)
}
impl DraftCacheState {
    /// Advances sequence slot 0 by one drafted token.
    ///
    /// The token id is not stored; only the bookkeeping moves. If the current
    /// position falls on a logical page the slot's page table does not cover
    /// yet, one physical page is popped from the free list and appended to
    /// the table before the position and the generated-token counter are
    /// incremented.
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::WorkspaceRequired`] carrying
    /// `num_heads * page_size * head_dim * 4` when a page is needed but the
    /// free list is empty (the factor 4 is presumably bytes per `f32` —
    /// confirm against the cache element type).
    pub fn advance_token(&mut self, _token: u32) -> Result<(), DnnError> {
        const SLOT: usize = 0;

        let cursor = self.seq_positions[SLOT];
        let page_index = cursor / self.page_size;
        if self.page_tables[SLOT].len() <= page_index {
            match self.free_pages.pop() {
                Some(physical) => self.page_tables[SLOT].push(physical),
                None => {
                    let required = self.num_heads * self.page_size * self.head_dim * 4;
                    return Err(DnnError::WorkspaceRequired(required));
                }
            }
        }

        self.seq_positions[SLOT] += 1;
        self.total_tokens_generated += 1;
        Ok(())
    }

    /// Moves sequence slot 0 back to `position`.
    ///
    /// Pages beyond the `ceil(position / page_size)` needed for the new
    /// position are removed from the page table and appended to the free
    /// list. The generated-token counter is left untouched.
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::InvalidArgument`] when `position` lies past the
    /// slot's current position.
    pub fn rollback_to(&mut self, position: usize) -> Result<(), DnnError> {
        const SLOT: usize = 0;

        let current = self.seq_positions[SLOT];
        if current < position {
            return Err(DnnError::InvalidArgument(format!(
                "rollback target {} exceeds current position {}",
                position, current
            )));
        }

        let keep = position.div_ceil(self.page_size);
        let table = &mut self.page_tables[SLOT];
        if keep < table.len() {
            // Hand the surplus pages straight back to the allocator.
            self.free_pages.extend(table.drain(keep..));
        }

        self.seq_positions[SLOT] = position;
        Ok(())
    }

    /// Current position of the first sequence slot, or 0 when there is none.
    #[must_use]
    pub fn accepted_count(&self) -> usize {
        self.seq_positions.first().map_or(0, |&position| position)
    }

    /// Tokens generated so far minus the first slot's current position,
    /// saturating at zero.
    #[must_use]
    pub fn rejected_count(&self) -> usize {
        let kept = self.accepted_count();
        self.total_tokens_generated.saturating_sub(kept)
    }
}
/// Outcome of checking one batch of drafted tokens against the target model.
#[derive(Debug, Clone)]
pub struct TokenVerificationResult {
/// Number of leading draft tokens that were accepted, i.e. the index of the
/// first rejected token, or the batch length when nothing was rejected.
pub accepted: usize,
/// Replacement token chosen for the first rejected position; `None` when
/// every draft token was accepted.
pub correction: Option<u32>,
}
impl TargetCacheState {
    /// Checks drafted tokens in order and stops at the first rejection.
    ///
    /// Position `i` is tested with
    /// `accept_token(draft_probs[i], target_probs[i], rng_samples[i])`. On
    /// the first failure the result carries the number of tokens accepted so
    /// far plus a correction token from [`Self::sample_correction`]; if
    /// every token passes, `accepted` equals `draft_tokens.len()` and
    /// `correction` is `None`.
    ///
    /// # Panics
    ///
    /// Debug builds assert that all four slices have the same length. In
    /// release builds a slice shorter than `draft_tokens` panics on indexing.
    #[must_use]
    pub fn verify_tokens(
        &self,
        draft_tokens: &[u32],
        draft_probs: &[f32],
        target_probs: &[f32],
        rng_samples: &[f32],
    ) -> TokenVerificationResult {
        debug_assert_eq!(draft_tokens.len(), draft_probs.len());
        debug_assert_eq!(draft_tokens.len(), target_probs.len());
        debug_assert_eq!(draft_tokens.len(), rng_samples.len());

        let gamma = draft_tokens.len();
        let first_rejected = (0..gamma)
            .find(|&idx| !accept_token(draft_probs[idx], target_probs[idx], rng_samples[idx]));

        match first_rejected {
            Some(idx) => TokenVerificationResult {
                accepted: idx,
                correction: Some(self.sample_correction(idx, target_probs, draft_probs)),
            },
            None => TokenVerificationResult {
                accepted: gamma,
                correction: None,
            },
        }
    }

    /// Picks a correction token from the clamped residual
    /// `max(target - draft, 0)` computed element-wise over the two slices.
    ///
    /// The result is the index of the largest residual, cast to `u32`. When
    /// the residuals sum to zero or less, the index of the largest entry in
    /// `target_probs` is returned instead. Ties go to the last index, and
    /// empty input yields 0. Selection is a deterministic arg-max; no random
    /// draw is involved. `_position` is currently unused.
    #[must_use]
    pub fn sample_correction(
        &self,
        _position: usize,
        target_probs: &[f32],
        draft_probs: &[f32],
    ) -> u32 {
        let residual: Vec<f32> = target_probs
            .iter()
            .zip(draft_probs)
            .map(|(&target, &draft)| (target - draft).max(0.0))
            .collect();
        let residual_mass: f32 = residual.iter().sum();

        // With no positive residual mass, fall back to the raw target values.
        let candidates: &[f32] = if residual_mass <= 0.0 {
            target_probs
        } else {
            &residual
        };

        let winner = candidates
            .iter()
            .enumerate()
            .max_by(|(_, lhs), (_, rhs)| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(idx, _)| idx);
        winner as u32
    }
}
/// Configuration bundle handed to `SpeculativeDecoder::new`.
#[derive(Debug, Clone)]
pub struct SpecDecConfig {
/// Underlying cache/model configuration; it is validated and used to build
/// both the draft and the target cache state.
pub inner: SpeculativeDecodeConfig,
/// NOTE(review): presumably the intended number of draft tokens per step.
/// Nothing in this file reads it — `step` takes its batch size from the
/// length of `draft_tokens` — so confirm how callers use it.
pub gamma: usize,
}
/// Result of one `SpeculativeDecoder::step` call.
#[derive(Debug, Clone)]
pub struct SpecDecOutput {
/// The accepted prefix of the draft tokens, followed by the correction
/// token when a draft token was rejected.
pub accepted_tokens: Vec<u32>,
/// Accepted draft tokens divided by the number of draft tokens in the step
/// (the correction token is not counted).
pub acceptance_rate: f32,
}
/// Drives draft-then-verify decoding steps over a draft cache and a target
/// cache, keeping running acceptance statistics.
pub struct SpeculativeDecoder {
// Configuration the decoder was built from.
config: SpecDecConfig,
// Bookkeeping for the draft model's paged cache.
draft_state: DraftCacheState,
// Bookkeeping for the target model's cache (verified position).
target_state: TargetCacheState,
// Snapshot taken at the start of `step`; `None` between steps.
active_checkpoint: Option<CacheCheckpoint>,
// Draft tokens accepted across all steps.
total_accepted: usize,
// Draft tokens rejected across all steps.
total_rejected: usize,
}
impl SpeculativeDecoder {
    /// Creates a decoder with freshly built draft and target cache states and
    /// zeroed statistics.
    ///
    /// # Errors
    ///
    /// Propagates the error from `config.inner.validate()`.
    pub fn new(config: SpecDecConfig) -> Result<Self, DnnError> {
        config.inner.validate()?;
        let draft_state = DraftCacheState::new(&config.inner);
        let target_state = TargetCacheState::new(&config.inner);
        Ok(Self {
            config,
            draft_state,
            target_state,
            active_checkpoint: None,
            total_accepted: 0,
            total_rejected: 0,
        })
    }

    /// Configuration the decoder was built from.
    #[must_use]
    pub fn config(&self) -> &SpecDecConfig {
        &self.config
    }

    /// Read-only view of the draft cache bookkeeping.
    #[must_use]
    pub fn draft_state(&self) -> &DraftCacheState {
        &self.draft_state
    }

    /// Read-only view of the target cache bookkeeping.
    #[must_use]
    pub fn target_state(&self) -> &TargetCacheState {
        &self.target_state
    }

    /// Draft tokens accepted across all completed steps.
    #[must_use]
    pub fn total_accepted(&self) -> usize {
        self.total_accepted
    }

    /// Draft tokens rejected across all completed steps.
    #[must_use]
    pub fn total_rejected(&self) -> usize {
        self.total_rejected
    }

    /// Runs one draft-then-verify step.
    ///
    /// The draft cache is advanced by every token in `draft_tokens`, the
    /// tokens are verified against the target probabilities, and the draft
    /// cache is then rolled back so that only the accepted prefix remains.
    /// The target's verified position and the running totals are advanced by
    /// the number of accepted / rejected draft tokens.
    ///
    /// The returned `accepted_tokens` holds the accepted prefix followed by
    /// the correction token when a rejection occurred; `acceptance_rate` is
    /// `accepted / draft_tokens.len()`.
    ///
    /// # Errors
    ///
    /// * [`DnnError::InvalidArgument`] when `draft_tokens` is empty or the
    ///   four slices differ in length.
    /// * Any error from advancing the draft cache (e.g. no free pages). In
    ///   that case the draft positions, page tables and free list are
    ///   restored to their state before the call and no statistics change.
    pub fn step(
        &mut self,
        draft_tokens: &[u32],
        draft_probs: &[f32],
        target_probs: &[f32],
        rng_samples: &[f32],
    ) -> Result<SpecDecOutput, DnnError> {
        let gamma = draft_tokens.len();
        if gamma == 0 {
            return Err(DnnError::InvalidArgument(
                "draft_tokens must not be empty".into(),
            ));
        }
        if draft_probs.len() != gamma || target_probs.len() != gamma || rng_samples.len() != gamma {
            return Err(DnnError::InvalidArgument(format!(
                "all slices must have length gamma={}, got draft_probs={}, target_probs={}, rng_samples={}",
                gamma,
                draft_probs.len(),
                target_probs.len(),
                rng_samples.len(),
            )));
        }

        let draft_pos_before = self.draft_state.seq_positions[0];
        // Snapshot the draft bookkeeping so a failed draft can be undone. The
        // checkpoint is stamped with the running accepted total.
        self.active_checkpoint = Some(CacheCheckpoint {
            draft_positions: self.draft_state.seq_positions.clone(),
            draft_page_tables: self.draft_state.page_tables.clone(),
            draft_free_pages: self.draft_state.free_pages.clone(),
            target_position: self.target_state.verified_position,
            timestamp: self.total_accepted as u64,
        });

        for &tok in draft_tokens {
            if let Err(err) = self.draft_state.advance_token(tok) {
                // Do not leave the cache half-advanced or the checkpoint
                // dangling: put back exactly what was snapshotted.
                if let Some(checkpoint) = self.active_checkpoint.take() {
                    self.draft_state.seq_positions = checkpoint.draft_positions;
                    self.draft_state.page_tables = checkpoint.draft_page_tables;
                    self.draft_state.free_pages = checkpoint.draft_free_pages;
                }
                return Err(err);
            }
        }
        // Drafting succeeded; the snapshot was only needed to undo a failure.
        self.active_checkpoint = None;

        let verification =
            self.target_state
                .verify_tokens(draft_tokens, draft_probs, target_probs, rng_samples);
        let accepted_count = verification.accepted;
        let rejected_count = gamma - accepted_count;

        if rejected_count > 0 {
            // Trim the draft cache to the accepted prefix. Unlike resetting
            // and replaying the accepted tokens, this leaves other sequence
            // slots untouched and does not count accepted tokens twice in the
            // draft state's generated-token counter. The target never exceeds
            // the current position, so this cannot fail in practice.
            self.draft_state
                .rollback_to(draft_pos_before + accepted_count)?;
        }

        self.total_accepted += accepted_count;
        self.total_rejected += rejected_count;
        self.target_state.verified_position += accepted_count;

        let mut accepted_tokens: Vec<u32> = draft_tokens[..accepted_count].to_vec();
        if let Some(correction) = verification.correction {
            accepted_tokens.push(correction);
        }
        let acceptance_rate = accepted_count as f32 / gamma as f32;

        Ok(SpecDecOutput {
            accepted_tokens,
            acceptance_rate,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small draft/target configuration shared by every test.
    fn base_config() -> SpeculativeDecodeConfig {
        SpeculativeDecodeConfig {
            draft_num_layers: 4,
            draft_num_heads: 4,
            draft_head_dim: 64,
            target_num_layers: 12,
            target_num_heads: 8,
            target_head_dim: 128,
            max_draft_tokens: 5,
            page_size: 16,
            max_pages: 32,
            acceptance_threshold: 0.8,
        }
    }

    /// Decoder configuration wrapping [`base_config`].
    fn dec_config() -> SpecDecConfig {
        SpecDecConfig {
            inner: base_config(),
            gamma: 4,
        }
    }

    #[test]
    fn test_accept_token_always_accepts_when_target_gte_draft() {
        assert!(accept_token(0.3, 0.3, 0.999));
        assert!(accept_token(0.1, 0.5, 0.999));
        assert!(accept_token(0.2, 0.8, 0.0));
    }

    #[test]
    fn test_accept_token_rejects_when_zero_target_prob() {
        assert!(!accept_token(0.5, 0.0, 0.0));
        assert!(!accept_token(0.5, 0.0, 0.999));
    }

    #[test]
    fn test_accept_token_rejects_when_draft_prob_zero() {
        assert!(!accept_token(0.0, 0.5, 0.0));
    }

    #[test]
    fn test_acceptance_rate_calculation() {
        // Ratio is 0.4 / 0.8 = 0.5: samples below it pass, samples above fail.
        assert!(accept_token(0.8, 0.4, 0.3));
        assert!(!accept_token(0.8, 0.4, 0.7));
    }

    #[test]
    fn test_draft_state_advance_and_rollback() {
        let mut state = DraftCacheState::new(&base_config());
        state.advance_token(10).expect("advance 0");
        state.advance_token(20).expect("advance 1");
        assert_eq!(state.seq_positions[0], 2);
        state.rollback_to(1).expect("rollback");
        assert_eq!(state.seq_positions[0], 1);
    }

    #[test]
    fn test_draft_state_rollback_beyond_position_errors() {
        let mut state = DraftCacheState::new(&base_config());
        state.advance_token(1).expect("advance");
        let err = state.rollback_to(5);
        assert!(err.is_err());
    }

    #[test]
    fn test_draft_state_accepted_rejected_count() {
        let mut state = DraftCacheState::new(&base_config());
        for t in 0..5u32 {
            state.advance_token(t).expect("advance");
        }
        assert_eq!(state.accepted_count(), 5);
        assert_eq!(state.rejected_count(), 0);

        // Rolling back two tokens turns them into rejections.
        state.rollback_to(3).expect("rollback");
        assert_eq!(state.accepted_count(), 3);
        assert_eq!(state.rejected_count(), 2);
    }

    #[test]
    fn test_verification_all_accepted() {
        let state = TargetCacheState::new(&base_config());
        let draft_tokens = [1u32, 2, 3];
        let draft_probs = [0.2f32, 0.3, 0.5];
        let target_probs = [0.4f32, 0.6, 1.0];
        let rng = [0.99f32, 0.99, 0.99];
        let result = state.verify_tokens(&draft_tokens, &draft_probs, &target_probs, &rng);
        assert_eq!(result.accepted, 3);
        assert!(result.correction.is_none());
    }

    #[test]
    fn test_verification_partial_acceptance() {
        let state = TargetCacheState::new(&base_config());
        let draft_tokens = [1u32, 2, 3];
        let draft_probs = [0.6f32, 0.3, 0.5];
        // The second token has ratio 0.1 / 0.3, below its sample of 0.5.
        let target_probs = [0.6f32, 0.1, 0.5];
        let rng = [0.01f32, 0.5, 0.01];
        let result = state.verify_tokens(&draft_tokens, &draft_probs, &target_probs, &rng);
        assert_eq!(result.accepted, 1);
        assert!(result.correction.is_some());
    }

    #[test]
    fn test_verification_first_rejected() {
        let state = TargetCacheState::new(&base_config());
        let draft_tokens = [42u32];
        let draft_probs = [0.9f32];
        let target_probs = [0.1f32];
        let rng = [0.5f32];
        let result = state.verify_tokens(&draft_tokens, &draft_probs, &target_probs, &rng);
        assert_eq!(result.accepted, 0);
        assert!(result.correction.is_some());
    }

    #[test]
    fn test_spec_decoder_all_accepted() {
        let mut dec = SpeculativeDecoder::new(dec_config()).expect("new decoder");
        let tokens = [1u32, 2, 3, 4];
        let draft_p = [0.1f32; 4];
        let target_p = [0.9f32; 4];
        let rng = [0.99f32; 4];
        let out = dec.step(&tokens, &draft_p, &target_p, &rng).expect("step");
        assert_eq!(out.accepted_tokens, tokens);
        assert!((out.acceptance_rate - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_spec_decoder_partial_acceptance_with_correction() {
        let mut dec = SpeculativeDecoder::new(dec_config()).expect("new decoder");
        let tokens = [1u32, 2, 3, 4];
        let draft_p = [0.1f32, 0.8, 0.5, 0.5];
        let target_p = [0.9f32, 0.1, 0.5, 0.5];
        let rng = [0.01f32, 0.5, 0.01, 0.01];
        let out = dec.step(&tokens, &draft_p, &target_p, &rng).expect("step");
        // One accepted draft token plus the correction token.
        assert_eq!(out.accepted_tokens.len(), 2);
        assert_eq!(out.accepted_tokens[0], 1);
        assert!(out.acceptance_rate < 1.0);
        assert!((out.acceptance_rate - 0.25).abs() < f32::EPSILON);
        // The draft cache keeps only the accepted prefix.
        assert_eq!(dec.draft_state().seq_positions[0], 1);
        assert_eq!(dec.total_accepted(), 1);
        assert_eq!(dec.total_rejected(), 3);
    }

    #[test]
    fn test_spec_decoder_empty_tokens_error() {
        let mut dec = SpeculativeDecoder::new(dec_config()).expect("new decoder");
        let err = dec.step(&[], &[], &[], &[]);
        assert!(err.is_err());
    }

    #[test]
    fn test_spec_decoder_mismatched_lengths_error() {
        let mut dec = SpeculativeDecoder::new(dec_config()).expect("new decoder");
        let err = dec.step(&[1u32], &[0.5f32, 0.5], &[0.5f32], &[0.5f32]);
        assert!(err.is_err());
    }

    #[test]
    fn test_spec_decoder_running_totals() {
        let mut dec = SpeculativeDecoder::new(dec_config()).expect("new decoder");
        dec.step(&[1u32, 2, 3, 4], &[0.1f32; 4], &[0.9f32; 4], &[0.01f32; 4])
            .expect("step 1");
        assert_eq!(dec.total_accepted(), 4);
        assert_eq!(dec.total_rejected(), 0);

        // A second step that rejects at index 1 adds 1 accepted / 3 rejected.
        dec.step(
            &[5u32, 6, 7, 8],
            &[0.1f32, 0.8, 0.5, 0.5],
            &[0.9f32, 0.1, 0.5, 0.5],
            &[0.01f32, 0.5, 0.01, 0.01],
        )
        .expect("step 2");
        assert_eq!(dec.total_accepted(), 5);
        assert_eq!(dec.total_rejected(), 3);
    }
}