pub mod decoder;
pub mod kv_manager;
pub use decoder::{
SpecDecConfig, SpecDecOutput, SpeculativeDecoder, TokenVerificationResult, accept_token,
};
pub use kv_manager::{KvCheckpoint, KvManager};
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::ir::PtxType;
use crate::error::DnnError;
/// Thread-per-block limit passed to `max_threads_per_block` for every kernel
/// emitted by `SpeculativeDecodePlan`.
const SPEC_BLOCK_SIZE: u32 = 256;
/// Shape and capacity parameters shared by the draft and target KV caches.
///
/// The `draft_*` fields describe the proposal model, the `target_*` fields the
/// verifying model; both caches use the same `page_size` and `max_pages`.
#[derive(Debug, Clone)]
pub struct SpeculativeDecodeConfig {
    /// Number of layers in the draft model.
    pub draft_num_layers: usize,
    /// Number of attention heads in the draft model.
    pub draft_num_heads: usize,
    /// Per-head dimension of the draft model.
    pub draft_head_dim: usize,
    /// Number of layers in the target model.
    pub target_num_layers: usize,
    /// Number of attention heads in the target model.
    pub target_num_heads: usize,
    /// Per-head dimension of the target model.
    pub target_head_dim: usize,
    /// Maximum number of speculative tokens per round (must be non-zero).
    pub max_draft_tokens: usize,
    /// Tokens stored per KV page.
    pub page_size: usize,
    /// Physical pages available to each cache.
    pub max_pages: usize,
    /// Acceptance threshold; must lie in `[0.0, 1.0]`.
    pub acceptance_threshold: f32,
}

impl SpeculativeDecodeConfig {
    /// Checks that every size field is non-zero and that
    /// `acceptance_threshold` lies in the closed range `[0.0, 1.0]`.
    ///
    /// Size fields are checked in declaration order and only the first
    /// violation is reported. A NaN threshold is rejected.
    ///
    /// # Errors
    /// Returns [`DnnError::InvalidArgument`] naming the first invalid field.
    pub fn validate(&self) -> Result<(), DnnError> {
        let must_be_positive = [
            ("draft_num_layers", self.draft_num_layers),
            ("draft_num_heads", self.draft_num_heads),
            ("draft_head_dim", self.draft_head_dim),
            ("target_num_layers", self.target_num_layers),
            ("target_num_heads", self.target_num_heads),
            ("target_head_dim", self.target_head_dim),
            ("max_draft_tokens", self.max_draft_tokens),
            ("page_size", self.page_size),
            ("max_pages", self.max_pages),
        ];
        if let Some((field, _)) = must_be_positive
            .into_iter()
            .find(|&(_, value)| value == 0)
        {
            return Err(DnnError::InvalidArgument(format!("{field} must be > 0")));
        }

        let threshold = self.acceptance_threshold;
        if (0.0..=1.0).contains(&threshold) {
            Ok(())
        } else {
            Err(DnnError::InvalidArgument(format!(
                "acceptance_threshold must be in [0.0, 1.0], got {threshold}"
            )))
        }
    }

    /// Element count of one draft-cache page: `heads * page_size * head_dim`.
    #[must_use]
    pub fn draft_page_elements(&self) -> usize {
        let heads_by_tokens = self.draft_num_heads * self.page_size;
        heads_by_tokens * self.draft_head_dim
    }

    /// Element count of one target-cache page: `heads * page_size * head_dim`.
    #[must_use]
    pub fn target_page_elements(&self) -> usize {
        let heads_by_tokens = self.target_num_heads * self.page_size;
        heads_by_tokens * self.target_head_dim
    }
}
/// Paging state of the draft model's KV cache.
///
/// Sequences are addressed by index; each owns a list of physical pages
/// taken from a single shared free list.
#[derive(Debug, Clone)]
pub struct DraftCacheState {
    /// Draft-model layer count (copied from the config).
    pub num_layers: usize,
    /// Draft-model head count (copied from the config).
    pub num_heads: usize,
    /// Draft-model per-head dimension (copied from the config).
    pub head_dim: usize,
    /// Tokens per page.
    pub page_size: usize,
    /// Total number of physical pages in the pool.
    pub max_pages: usize,
    /// Number of tokens currently stored for each sequence.
    pub seq_positions: Vec<usize>,
    /// For each sequence, the physical page backing each logical page.
    pub page_tables: Vec<Vec<usize>>,
    /// Physical pages not assigned to any sequence.
    pub free_pages: Vec<usize>,
    /// Running count of draft KV entries appended.
    pub total_tokens_generated: usize,
}

impl DraftCacheState {
    /// Starts with one empty sequence and every page in `0..max_pages` free.
    fn new(config: &SpeculativeDecodeConfig) -> Self {
        Self {
            num_layers: config.draft_num_layers,
            num_heads: config.draft_num_heads,
            head_dim: config.draft_head_dim,
            page_size: config.page_size,
            max_pages: config.max_pages,
            seq_positions: vec![0],
            page_tables: vec![Vec::new()],
            free_pages: (0..config.max_pages).collect(),
            total_tokens_generated: 0,
        }
    }

    /// Number of pages currently handed out to sequences.
    #[must_use]
    pub fn allocated_pages(&self) -> usize {
        let unassigned = self.free_pages.len();
        self.max_pages - unassigned
    }
}
/// Paging state of the target model's KV cache.
///
/// Mirrors the draft cache layout, but tracks a verified position instead of
/// a generated-token counter.
#[derive(Debug, Clone)]
pub struct TargetCacheState {
    /// Target-model layer count (copied from the config).
    pub num_layers: usize,
    /// Target-model head count (copied from the config).
    pub num_heads: usize,
    /// Target-model per-head dimension (copied from the config).
    pub head_dim: usize,
    /// Tokens per page.
    pub page_size: usize,
    /// Total number of physical pages in the pool.
    pub max_pages: usize,
    /// Number of tokens currently stored for each sequence.
    pub seq_positions: Vec<usize>,
    /// For each sequence, the physical page backing each logical page.
    pub page_tables: Vec<Vec<usize>>,
    /// Physical pages not assigned to any sequence.
    pub free_pages: Vec<usize>,
    /// Incremented once per target KV entry appended; checkpoints record it
    /// as `target_position`.
    pub verified_position: usize,
}

impl TargetCacheState {
    /// Starts with one empty sequence and every page in `0..max_pages` free.
    fn new(config: &SpeculativeDecodeConfig) -> Self {
        Self {
            num_layers: config.target_num_layers,
            num_heads: config.target_num_heads,
            head_dim: config.target_head_dim,
            page_size: config.page_size,
            max_pages: config.max_pages,
            seq_positions: vec![0],
            page_tables: vec![Vec::new()],
            free_pages: (0..config.max_pages).collect(),
            verified_position: 0,
        }
    }

    /// Number of pages currently handed out to sequences.
    #[must_use]
    pub fn allocated_pages(&self) -> usize {
        let unassigned = self.free_pages.len();
        self.max_pages - unassigned
    }
}
/// Snapshot taken by `SpeculativeKvManager::create_checkpoint`.
///
/// Holds copies of the draft cache's per-sequence positions, page tables and
/// free-page list so that a rollback can restore them wholesale.
#[derive(Debug, Clone)]
pub struct CacheCheckpoint {
    /// Draft `seq_positions` at snapshot time.
    pub draft_positions: Vec<usize>,
    /// Draft `page_tables` at snapshot time.
    pub draft_page_tables: Vec<Vec<usize>>,
    /// Draft `free_pages` at snapshot time.
    pub draft_free_pages: Vec<usize>,
    /// Target cache `verified_position` at snapshot time.
    pub target_position: usize,
    /// Value of the manager's checkpoint counter (a sequence number, not
    /// wall-clock time).
    pub timestamp: u64,
}
/// Outcome of `SpeculativeKvManager::accept_tokens`.
#[derive(Debug, Clone)]
pub struct VerificationResult {
    /// Number of drafted tokens accepted (and appended to the target cache).
    pub accepted_count: usize,
    /// Number of tokens drafted since the checkpoint.
    pub total_drafted: usize,
    /// `accepted_count / total_drafted`, or `0.0` when nothing was drafted.
    pub acceptance_rate: f32,
    /// Draft pages returned to the free list when the rejected suffix was
    /// discarded (0 when every drafted token was accepted).
    pub rolled_back_pages: usize,
    /// Bonus-token flag; `accept_tokens` currently always sets it to `true`.
    pub bonus_token: bool,
}
/// Point-in-time counters reported by `SpeculativeKvManager::stats`.
#[derive(Debug, Clone, Default)]
pub struct SpeculativeDecodeStats {
    /// Draft pages currently assigned to sequences.
    pub draft_pages_allocated: usize,
    /// Target pages currently assigned to sequences.
    pub target_pages_allocated: usize,
    /// Checkpoints created over the manager's lifetime.
    pub total_checkpoints_created: usize,
    /// Rollbacks performed over the manager's lifetime.
    pub total_rollbacks: usize,
    /// Mean of the per-call acceptance rates (0.0 before any call).
    pub average_acceptance_rate: f32,
}
/// Paged KV-cache bookkeeping for a draft model and a target model during
/// speculative decoding.
///
/// Only page tables, positions and free lists are tracked here; the manager
/// hands out element offsets but holds no KV storage itself.
pub struct SpeculativeKvManager {
    /// Configuration the caches were built from (validated in `new`).
    config: SpeculativeDecodeConfig,
    /// Paging state of the draft model's cache.
    draft_cache: DraftCacheState,
    /// Paging state of the target model's cache.
    target_cache: TargetCacheState,
    /// Most recent draft-cache snapshot, if any.
    checkpoint: Option<CacheCheckpoint>,
    /// Incremented per checkpoint; stored as the checkpoint's `timestamp`.
    checkpoint_counter: u64,
    /// Number of checkpoints created (reported by `stats`).
    total_checkpoints: usize,
    /// Number of rollbacks performed (reported by `stats`).
    total_rollbacks: usize,
    /// Running sum of per-call acceptance rates, averaged in `stats`.
    acceptance_sum: f64,
    /// Number of rates summed into `acceptance_sum`.
    acceptance_count: usize,
}
impl SpeculativeKvManager {
    /// Builds a manager whose draft and target caches each start with a
    /// single empty sequence (id `0`) and all `max_pages` pages free.
    ///
    /// # Errors
    /// Propagates any error from [`SpeculativeDecodeConfig::validate`].
    pub fn new(config: SpeculativeDecodeConfig) -> Result<Self, DnnError> {
        config.validate()?;
        let draft_cache = DraftCacheState::new(&config);
        let target_cache = TargetCacheState::new(&config);
        Ok(Self {
            config,
            draft_cache,
            target_cache,
            checkpoint: None,
            checkpoint_counter: 0,
            total_checkpoints: 0,
            total_rollbacks: 0,
            acceptance_sum: 0.0,
            acceptance_count: 0,
        })
    }

    /// Builds the out-of-range error shared by every `seq_id`-taking method.
    fn seq_out_of_range(seq_id: usize, len: usize) -> DnnError {
        DnnError::InvalidArgument(format!("seq_id {seq_id} out of range (max {len})"))
    }

    /// Fails unless `seq_id` names an existing draft sequence.
    fn check_draft_seq(&self, seq_id: usize) -> Result<(), DnnError> {
        let len = self.draft_cache.seq_positions.len();
        if seq_id < len {
            Ok(())
        } else {
            Err(Self::seq_out_of_range(seq_id, len))
        }
    }

    /// Fails unless `seq_id` names an existing target sequence.
    fn check_target_seq(&self, seq_id: usize) -> Result<(), DnnError> {
        let len = self.target_cache.seq_positions.len();
        if seq_id < len {
            Ok(())
        } else {
            Err(Self::seq_out_of_range(seq_id, len))
        }
    }

    /// Snapshots the draft cache (all sequences) and the target's verified
    /// position, replacing any previous checkpoint.
    ///
    /// # Errors
    /// `InvalidArgument` if `seq_id` is not a draft sequence.
    pub fn create_checkpoint(&mut self, seq_id: usize) -> Result<(), DnnError> {
        self.check_draft_seq(seq_id)?;
        self.checkpoint_counter += 1;
        self.total_checkpoints += 1;
        self.checkpoint = Some(CacheCheckpoint {
            draft_positions: self.draft_cache.seq_positions.clone(),
            draft_page_tables: self.draft_cache.page_tables.clone(),
            draft_free_pages: self.draft_cache.free_pages.clone(),
            target_position: self.target_cache.verified_position,
            timestamp: self.checkpoint_counter,
        });
        Ok(())
    }

    /// Reserves the next draft slot of `seq_id`, allocating a page when the
    /// position crosses into a new logical page.
    ///
    /// Returns `(k_offset, v_offset)` element offsets; both are equal
    /// (presumably indexing separate K and V buffers — confirm with callers).
    ///
    /// # Errors
    /// * `InvalidArgument` if `seq_id` is not a draft sequence.
    /// * `WorkspaceRequired` if no free page is left.
    pub fn append_draft_kv(&mut self, seq_id: usize) -> Result<(usize, usize), DnnError> {
        self.check_draft_seq(seq_id)?;
        let page_size = self.draft_cache.page_size;
        let pos = self.draft_cache.seq_positions[seq_id];
        let logical_page = pos / page_size;
        let offset_in_page = pos % page_size;
        if logical_page >= self.draft_cache.page_tables[seq_id].len() {
            // `* 4` presumably converts elements to bytes for f32 storage.
            let phys = self.draft_cache.free_pages.pop().ok_or_else(|| {
                DnnError::WorkspaceRequired(self.config.draft_page_elements() * 4)
            })?;
            self.draft_cache.page_tables[seq_id].push(phys);
        }
        let phys_page = self.draft_cache.page_tables[seq_id][logical_page];
        let page_elements = self.config.draft_page_elements();
        let token_stride = self.draft_cache.num_heads * self.draft_cache.head_dim;
        let k_offset = phys_page * page_elements + offset_in_page * token_stride;
        let v_offset = k_offset;
        self.draft_cache.seq_positions[seq_id] += 1;
        self.draft_cache.total_tokens_generated += 1;
        Ok((k_offset, v_offset))
    }

    /// Reserves the next target slot of `seq_id`, allocating a page when the
    /// position crosses into a new logical page, and advances the verified
    /// position.
    ///
    /// # Errors
    /// * `InvalidArgument` if `seq_id` is not a target sequence.
    /// * `WorkspaceRequired` if no free page is left.
    pub fn append_target_kv(&mut self, seq_id: usize) -> Result<(usize, usize), DnnError> {
        self.check_target_seq(seq_id)?;
        let page_size = self.target_cache.page_size;
        let pos = self.target_cache.seq_positions[seq_id];
        let logical_page = pos / page_size;
        let offset_in_page = pos % page_size;
        if logical_page >= self.target_cache.page_tables[seq_id].len() {
            let phys = self.target_cache.free_pages.pop().ok_or_else(|| {
                DnnError::WorkspaceRequired(self.config.target_page_elements() * 4)
            })?;
            self.target_cache.page_tables[seq_id].push(phys);
        }
        let phys_page = self.target_cache.page_tables[seq_id][logical_page];
        let page_elements = self.config.target_page_elements();
        let token_stride = self.target_cache.num_heads * self.target_cache.head_dim;
        let k_offset = phys_page * page_elements + offset_in_page * token_stride;
        let v_offset = k_offset;
        self.target_cache.seq_positions[seq_id] += 1;
        self.target_cache.verified_position += 1;
        Ok((k_offset, v_offset))
    }

    /// Restores the draft cache to the checkpoint, consuming it, and returns
    /// how many pages were freed.
    ///
    /// # Errors
    /// `InvalidArgument` if `seq_id` is not a draft sequence or no checkpoint
    /// exists.
    pub fn rollback_to_checkpoint(&mut self, seq_id: usize) -> Result<usize, DnnError> {
        self.check_draft_seq(seq_id)?;
        let cp = self
            .checkpoint
            .take()
            .ok_or_else(|| DnnError::InvalidArgument("no checkpoint to rollback to".into()))?;
        let pages_before = self.draft_cache.allocated_pages();
        self.draft_cache.seq_positions = cp.draft_positions;
        self.draft_cache.page_tables = cp.draft_page_tables;
        self.draft_cache.free_pages = cp.draft_free_pages;
        let pages_after = self.draft_cache.allocated_pages();
        self.total_rollbacks += 1;
        Ok(pages_before.saturating_sub(pages_after))
    }

    /// Fails if appending `count` tokens to target sequence `seq_id` would
    /// need more pages than are currently free.
    ///
    /// The error matches what `append_target_kv` would raise, but is reported
    /// before any state is touched.
    fn ensure_target_capacity(&self, seq_id: usize, count: usize) -> Result<(), DnnError> {
        self.check_target_seq(seq_id)?;
        let cache = &self.target_cache;
        let end = cache.seq_positions[seq_id] + count;
        let pages_required = end.div_ceil(cache.page_size);
        let pages_missing = pages_required.saturating_sub(cache.page_tables[seq_id].len());
        if pages_missing > cache.free_pages.len() {
            return Err(DnnError::WorkspaceRequired(
                self.config.target_page_elements() * 4,
            ));
        }
        Ok(())
    }

    /// Shrinks draft sequence `seq_id` to `new_len` tokens, returning pages
    /// that no longer hold any of those tokens to the free list. Returns the
    /// number of pages released.
    fn truncate_draft(&mut self, seq_id: usize, new_len: usize) -> usize {
        let cache = &mut self.draft_cache;
        let pages_needed = new_len.div_ceil(cache.page_size);
        let table = &mut cache.page_tables[seq_id];
        let released = if pages_needed < table.len() {
            table.split_off(pages_needed)
        } else {
            Vec::new()
        };
        // Push in reverse so the LIFO free list hands these pages out again in
        // the order they were originally taken.
        cache.free_pages.extend(released.iter().rev().copied());
        cache.seq_positions[seq_id] = new_len;
        released.len()
    }

    /// Commits the first `count` tokens drafted since the checkpoint.
    ///
    /// The accepted tokens are appended to the target cache. If every drafted
    /// token is accepted, the checkpoint is cleared. Otherwise:
    /// * the draft sequence is truncated to `checkpoint position + count`;
    /// * only pages holding nothing but rejected tokens are freed, so the
    ///   accepted entries keep their pages;
    /// * the rollback counter is bumped;
    /// * the checkpoint is left in place.
    ///
    /// # Errors
    /// * `InvalidArgument` if `seq_id` is out of range, no checkpoint exists,
    ///   or `count` exceeds the number of drafted tokens.
    /// * `WorkspaceRequired` if the target cache lacks free pages for `count`
    ///   tokens; in that case nothing is modified.
    pub fn accept_tokens(
        &mut self,
        seq_id: usize,
        count: usize,
    ) -> Result<VerificationResult, DnnError> {
        self.check_draft_seq(seq_id)?;
        let cp = self.checkpoint.as_ref().ok_or_else(|| {
            DnnError::InvalidArgument("no checkpoint: call create_checkpoint first".into())
        })?;
        let draft_start = cp.draft_positions[seq_id];
        let draft_end = self.draft_cache.seq_positions[seq_id];
        let total_drafted = draft_end.saturating_sub(draft_start);
        if count > total_drafted {
            return Err(DnnError::InvalidArgument(format!(
                "cannot accept {count} tokens when only {total_drafted} were drafted"
            )));
        }

        // Check up front so a page shortage cannot leave the target cache
        // with only some of the accepted tokens appended.
        self.ensure_target_capacity(seq_id, count)?;
        for _ in 0..count {
            self.append_target_kv(seq_id)?;
        }

        let acceptance_rate = if total_drafted > 0 {
            count as f32 / total_drafted as f32
        } else {
            0.0
        };
        self.acceptance_sum += f64::from(acceptance_rate);
        self.acceptance_count += 1;

        let rolled_back_pages = if count < total_drafted {
            self.total_rollbacks += 1;
            self.truncate_draft(seq_id, draft_start + count)
        } else {
            self.checkpoint = None;
            0
        };

        Ok(VerificationResult {
            accepted_count: count,
            total_drafted,
            acceptance_rate,
            rolled_back_pages,
            bonus_token: true,
        })
    }

    /// Current length of draft sequence `seq_id`.
    ///
    /// # Errors
    /// `InvalidArgument` if `seq_id` is out of range.
    pub fn draft_seq_len(&self, seq_id: usize) -> Result<usize, DnnError> {
        let positions = &self.draft_cache.seq_positions;
        positions
            .get(seq_id)
            .copied()
            .ok_or_else(|| Self::seq_out_of_range(seq_id, positions.len()))
    }

    /// Current length of target sequence `seq_id`.
    ///
    /// # Errors
    /// `InvalidArgument` if `seq_id` is out of range.
    pub fn target_seq_len(&self, seq_id: usize) -> Result<usize, DnnError> {
        let positions = &self.target_cache.seq_positions;
        positions
            .get(seq_id)
            .copied()
            .ok_or_else(|| Self::seq_out_of_range(seq_id, positions.len()))
    }

    /// Empties sequence `seq_id` in both caches, returns its pages to the free
    /// lists, and clears any checkpoint.
    ///
    /// The target's `verified_position` is reduced by the number of entries
    /// dropped, so later checkpoints record a position consistent with the
    /// cache contents.
    ///
    /// # Errors
    /// `InvalidArgument` if `seq_id` is not a draft sequence.
    pub fn reset_sequence(&mut self, seq_id: usize) -> Result<(), DnnError> {
        self.check_draft_seq(seq_id)?;
        let draft = &mut self.draft_cache;
        draft.free_pages.extend(draft.page_tables[seq_id].drain(..));
        draft.seq_positions[seq_id] = 0;

        if seq_id < self.target_cache.seq_positions.len() {
            let target = &mut self.target_cache;
            let dropped = target.seq_positions[seq_id];
            target.free_pages.extend(target.page_tables[seq_id].drain(..));
            target.seq_positions[seq_id] = 0;
            target.verified_position = target.verified_position.saturating_sub(dropped);
        }
        self.checkpoint = None;
        Ok(())
    }

    /// Snapshot of page usage, checkpoint/rollback counts and mean acceptance.
    #[must_use]
    pub fn stats(&self) -> SpeculativeDecodeStats {
        let avg_rate = if self.acceptance_count > 0 {
            (self.acceptance_sum / self.acceptance_count as f64) as f32
        } else {
            0.0
        };
        SpeculativeDecodeStats {
            draft_pages_allocated: self.draft_cache.allocated_pages(),
            target_pages_allocated: self.target_cache.allocated_pages(),
            total_checkpoints_created: self.total_checkpoints,
            total_rollbacks: self.total_rollbacks,
            average_acceptance_rate: avg_rate,
        }
    }

    /// Read-only view of the draft cache state.
    #[must_use]
    pub fn draft_cache(&self) -> &DraftCacheState {
        &self.draft_cache
    }

    /// Read-only view of the target cache state.
    #[must_use]
    pub fn target_cache(&self) -> &TargetCacheState {
        &self.target_cache
    }

    /// The current checkpoint, if one exists.
    #[must_use]
    pub fn checkpoint(&self) -> Option<&CacheCheckpoint> {
        self.checkpoint.as_ref()
    }
}
/// Generator for the speculative-decoding PTX kernels (KV copy, verification,
/// rejection sampling).
pub struct SpeculativeDecodePlan {
    /// Configuration validated in `new`; exposed through `config()` but not
    /// read by the kernel generators themselves.
    config: SpeculativeDecodeConfig,
}
impl SpeculativeDecodePlan {
pub fn new(config: SpeculativeDecodeConfig) -> Result<Self, DnnError> {
config.validate()?;
Ok(Self { config })
}
pub fn generate_kv_copy_ptx(&self) -> Result<String, DnnError> {
let kernel_name = "spec_decode_kv_copy";
let ptx = KernelBuilder::new(kernel_name)
.target(SmVersion::Sm80)
.max_threads_per_block(SPEC_BLOCK_SIZE)
.param("src_ptr", PtxType::U64)
.param("dst_ptr", PtxType::U64)
.param("offsets_ptr", PtxType::U64)
.param("num_elements", PtxType::U32)
.body(|b| {
let tid = b.global_thread_id_x();
let n = b.load_param_u32("num_elements");
let tid2 = tid.clone();
b.if_lt_u32(tid, n, |b| {
let src_base = b.load_param_u64("src_ptr");
let dst_base = b.load_param_u64("dst_ptr");
let byte_off = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("shl.b32 {byte_off}, {tid2}, 2;"));
let byte_off_64 = b.alloc_reg(PtxType::U64);
b.raw_ptx(&format!("cvt.u64.u32 {byte_off_64}, {byte_off};"));
let src_addr = b.add_u64(src_base, byte_off_64.clone());
let dst_addr = b.add_u64(dst_base, byte_off_64);
let val = b.load_global_f32(src_addr);
b.store_global_f32(dst_addr, val);
});
b.ret();
})
.build()?;
Ok(ptx)
}
pub fn generate_verification_ptx(&self) -> Result<String, DnnError> {
let kernel_name = "spec_decode_verify";
let ptx = KernelBuilder::new(kernel_name)
.target(SmVersion::Sm80)
.max_threads_per_block(SPEC_BLOCK_SIZE)
.param("draft_logits", PtxType::U64)
.param("target_logits", PtxType::U64)
.param("accept_mask", PtxType::U64)
.param("vocab_size", PtxType::U32)
.param("num_tokens", PtxType::U32)
.param("threshold", PtxType::F32)
.body(|b| {
let tid = b.global_thread_id_x();
let num_tok = b.load_param_u32("num_tokens");
let tid2 = tid.clone();
b.if_lt_u32(tid, num_tok, |b| {
let draft_ptr = b.load_param_u64("draft_logits");
let target_ptr = b.load_param_u64("target_logits");
let thresh = b.load_param_f32("threshold");
let byte_off = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("shl.b32 {byte_off}, {tid2}, 2;"));
let byte_off_64 = b.alloc_reg(PtxType::U64);
b.raw_ptx(&format!("cvt.u64.u32 {byte_off_64}, {byte_off};"));
let d_addr = b.add_u64(draft_ptr, byte_off_64.clone());
let t_addr = b.add_u64(target_ptr, byte_off_64.clone());
let d_val = b.load_global_f32(d_addr);
let t_val = b.load_global_f32(t_addr);
let diff = b.sub_f32(t_val, d_val);
let abs_diff = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("abs.f32 {abs_diff}, {diff};"));
let pred = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.le.f32 {pred}, {abs_diff}, {thresh};"));
let result = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("selp.u32 {result}, 1, 0, {pred};"));
let mask_ptr = b.load_param_u64("accept_mask");
let out_addr = b.add_u64(mask_ptr, byte_off_64);
b.store_global_u32(out_addr, result);
});
b.ret();
})
.build()?;
Ok(ptx)
}
pub fn generate_rejection_sampling_ptx(&self) -> Result<String, DnnError> {
let kernel_name = "spec_decode_rejection_sample";
let ptx = KernelBuilder::new(kernel_name)
.target(SmVersion::Sm80)
.max_threads_per_block(SPEC_BLOCK_SIZE)
.param("draft_probs", PtxType::U64)
.param("target_probs", PtxType::U64)
.param("random_vals", PtxType::U64)
.param("output_tokens", PtxType::U64)
.param("num_tokens", PtxType::U32)
.body(|b| {
let tid = b.global_thread_id_x();
let n = b.load_param_u32("num_tokens");
let tid2 = tid.clone();
b.if_lt_u32(tid, n, |b| {
let draft_ptr = b.load_param_u64("draft_probs");
let target_ptr = b.load_param_u64("target_probs");
let rand_ptr = b.load_param_u64("random_vals");
let byte_off = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("shl.b32 {byte_off}, {tid2}, 2;"));
let byte_off_64 = b.alloc_reg(PtxType::U64);
b.raw_ptx(&format!("cvt.u64.u32 {byte_off_64}, {byte_off};"));
let d_addr = b.add_u64(draft_ptr, byte_off_64.clone());
let t_addr = b.add_u64(target_ptr, byte_off_64.clone());
let r_addr = b.add_u64(rand_ptr, byte_off_64.clone());
let p_draft = b.load_global_f32(d_addr);
let p_target = b.load_global_f32(t_addr);
let rand_val = b.load_global_f32(r_addr);
let ratio = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("div.rn.f32 {ratio}, {p_target}, {p_draft};"));
let one = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {one}, 0f3F800000;"));
let clamped = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("min.f32 {clamped}, {ratio}, {one};"));
let pred = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.lt.f32 {pred}, {rand_val}, {clamped};"));
let result = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("selp.u32 {result}, 1, 0, {pred};"));
let out_ptr = b.load_param_u64("output_tokens");
let out_addr = b.add_u64(out_ptr, byte_off_64);
b.store_global_u32(out_addr, result);
});
b.ret();
})
.build()?;
Ok(ptx)
}
#[must_use]
pub fn config(&self) -> &SpeculativeDecodeConfig {
&self.config
}
}
// Unit tests covering config validation, the KV manager's paging and
// checkpoint bookkeeping, and PTX kernel generation.
#[cfg(test)]
mod tests {
    use super::*;
    // Baseline valid config: 5 draft tokens per round, 16-token pages and
    // 64 pages per cache.
    fn test_config() -> SpeculativeDecodeConfig {
        SpeculativeDecodeConfig {
            draft_num_layers: 6,
            draft_num_heads: 8,
            draft_head_dim: 64,
            target_num_layers: 32,
            target_num_heads: 32,
            target_head_dim: 128,
            max_draft_tokens: 5,
            page_size: 16,
            max_pages: 64,
            acceptance_threshold: 0.9,
        }
    }
    #[test]
    fn config_validation_ok() {
        let cfg = test_config();
        assert!(cfg.validate().is_ok());
    }
    #[test]
    fn config_validation_zero_draft_layers() {
        let mut cfg = test_config();
        cfg.draft_num_layers = 0;
        assert!(cfg.validate().is_err());
    }
    #[test]
    fn config_validation_zero_page_size() {
        let mut cfg = test_config();
        cfg.page_size = 0;
        assert!(cfg.validate().is_err());
    }
    // Thresholds outside [0.0, 1.0] are rejected.
    #[test]
    fn config_validation_bad_threshold() {
        let mut cfg = test_config();
        cfg.acceptance_threshold = 1.5;
        assert!(cfg.validate().is_err());
    }
    // Page size in elements is heads * page_size * head_dim.
    #[test]
    fn config_page_elements() {
        let cfg = test_config();
        assert_eq!(cfg.draft_page_elements(), 8 * 16 * 64);
        assert_eq!(cfg.target_page_elements(), 32 * 16 * 128);
    }
    // A fresh manager has one empty sequence, no pages in use and no checkpoint.
    #[test]
    fn create_manager_initial_state() {
        let mgr = SpeculativeKvManager::new(test_config());
        assert!(mgr.is_ok());
        let mgr = mgr.expect("tested above");
        assert_eq!(mgr.draft_seq_len(0).expect("seq 0"), 0);
        assert_eq!(mgr.target_seq_len(0).expect("seq 0"), 0);
        assert!(mgr.checkpoint().is_none());
        let s = mgr.stats();
        assert_eq!(s.draft_pages_allocated, 0);
        assert_eq!(s.target_pages_allocated, 0);
        assert_eq!(s.total_checkpoints_created, 0);
        assert_eq!(s.total_rollbacks, 0);
    }
    // Drafting past a checkpoint and rolling back restores the draft length
    // and consumes the checkpoint.
    #[test]
    fn checkpoint_and_rollback() {
        let mut mgr = SpeculativeKvManager::new(test_config()).expect("valid config");
        mgr.append_draft_kv(0).expect("append 0");
        mgr.append_draft_kv(0).expect("append 1");
        assert_eq!(mgr.draft_seq_len(0).expect("seq 0"), 2);
        mgr.create_checkpoint(0).expect("checkpoint");
        assert!(mgr.checkpoint().is_some());
        mgr.append_draft_kv(0).expect("append 2");
        mgr.append_draft_kv(0).expect("append 3");
        assert_eq!(mgr.draft_seq_len(0).expect("seq 0"), 4);
        let freed = mgr.rollback_to_checkpoint(0).expect("rollback");
        assert_eq!(mgr.draft_seq_len(0).expect("seq 0"), 2);
        let _ = freed; assert!(mgr.checkpoint().is_none());
        let s = mgr.stats();
        assert_eq!(s.total_checkpoints_created, 1);
        assert_eq!(s.total_rollbacks, 1);
    }
    #[test]
    fn append_draft_kv_increments_position() {
        let mut mgr = SpeculativeKvManager::new(test_config()).expect("valid config");
        let (k, v) = mgr.append_draft_kv(0).expect("append");
        assert_eq!(mgr.draft_seq_len(0).expect("seq 0"), 1);
        let _ = (k, v); }
    #[test]
    fn append_target_kv_increments_position() {
        let mut mgr = SpeculativeKvManager::new(test_config()).expect("valid config");
        let (k, v) = mgr.append_target_kv(0).expect("append");
        assert_eq!(mgr.target_seq_len(0).expect("seq 0"), 1);
        let _ = (k, v); }
    // Accepting every drafted token copies all of them to the target cache
    // and frees no pages.
    #[test]
    fn accept_tokens_full() {
        let mut mgr = SpeculativeKvManager::new(test_config()).expect("valid config");
        mgr.create_checkpoint(0).expect("checkpoint");
        for _ in 0..5 {
            mgr.append_draft_kv(0).expect("draft");
        }
        assert_eq!(mgr.draft_seq_len(0).expect("seq 0"), 5);
        let result = mgr.accept_tokens(0, 5).expect("accept all");
        assert_eq!(result.accepted_count, 5);
        assert_eq!(result.total_drafted, 5);
        assert!((result.acceptance_rate - 1.0).abs() < f32::EPSILON);
        assert!(result.bonus_token);
        assert_eq!(result.rolled_back_pages, 0);
        assert_eq!(mgr.target_seq_len(0).expect("seq 0"), 5);
    }
    // Accepting 2 of 5 drafted tokens appends 2 target entries and trims the
    // draft sequence back to the accepted prefix.
    #[test]
    fn accept_tokens_partial() {
        let mut mgr = SpeculativeKvManager::new(test_config()).expect("valid config");
        mgr.create_checkpoint(0).expect("checkpoint");
        for _ in 0..5 {
            mgr.append_draft_kv(0).expect("draft");
        }
        let result = mgr.accept_tokens(0, 2).expect("accept partial");
        assert_eq!(result.accepted_count, 2);
        assert_eq!(result.total_drafted, 5);
        assert!((result.acceptance_rate - 0.4).abs() < f32::EPSILON);
        assert!(result.bonus_token);
        assert_eq!(mgr.target_seq_len(0).expect("seq 0"), 2);
        assert_eq!(mgr.draft_seq_len(0).expect("seq 0"), 2);
    }
    #[test]
    fn reset_sequence_clears_state() {
        let mut mgr = SpeculativeKvManager::new(test_config()).expect("valid config");
        for _ in 0..10 {
            mgr.append_draft_kv(0).expect("draft");
        }
        for _ in 0..3 {
            mgr.append_target_kv(0).expect("target");
        }
        mgr.reset_sequence(0).expect("reset");
        assert_eq!(mgr.draft_seq_len(0).expect("seq 0"), 0);
        assert_eq!(mgr.target_seq_len(0).expect("seq 0"), 0);
        assert!(mgr.checkpoint().is_none());
    }
    // Round 1 accepts 3/3; round 2 accepts 1/4, which counts as a rollback
    // and pulls the mean acceptance rate below 1.
    #[test]
    fn stats_tracking() {
        let mut mgr = SpeculativeKvManager::new(test_config()).expect("valid config");
        mgr.create_checkpoint(0).expect("cp");
        for _ in 0..3 {
            mgr.append_draft_kv(0).expect("draft");
        }
        mgr.accept_tokens(0, 3).expect("accept");
        mgr.create_checkpoint(0).expect("cp");
        for _ in 0..4 {
            mgr.append_draft_kv(0).expect("draft");
        }
        mgr.accept_tokens(0, 1).expect("accept");
        let s = mgr.stats();
        assert_eq!(s.total_checkpoints_created, 2);
        assert!(s.total_rollbacks >= 1);
        assert!(s.average_acceptance_rate > 0.0);
        assert!(s.average_acceptance_rate < 1.0);
    }
    // The PTX tests only check the entry name and parameter names, not the
    // kernel bodies.
    #[test]
    fn kv_copy_ptx_generation() {
        let plan = SpeculativeDecodePlan::new(test_config()).expect("plan");
        let ptx = plan.generate_kv_copy_ptx().expect("ptx");
        assert!(ptx.contains(".entry spec_decode_kv_copy"));
        assert!(ptx.contains("%param_src_ptr"));
        assert!(ptx.contains("%param_dst_ptr"));
        assert!(ptx.contains("%param_num_elements"));
    }
    #[test]
    fn verification_ptx_generation() {
        let plan = SpeculativeDecodePlan::new(test_config()).expect("plan");
        let ptx = plan.generate_verification_ptx().expect("ptx");
        assert!(ptx.contains(".entry spec_decode_verify"));
        assert!(ptx.contains("%param_draft_logits"));
        assert!(ptx.contains("%param_target_logits"));
        assert!(ptx.contains("%param_threshold"));
    }
    #[test]
    fn rejection_sampling_ptx_generation() {
        let plan = SpeculativeDecodePlan::new(test_config()).expect("plan");
        let ptx = plan.generate_rejection_sampling_ptx().expect("ptx");
        assert!(ptx.contains(".entry spec_decode_rejection_sample"));
        assert!(ptx.contains("%param_draft_probs"));
        assert!(ptx.contains("%param_target_probs"));
        assert!(ptx.contains("%param_random_vals"));
    }
    // Edge case: a single-token draft round that is fully accepted.
    #[test]
    fn max_draft_tokens_one() {
        let mut cfg = test_config();
        cfg.max_draft_tokens = 1;
        let mut mgr = SpeculativeKvManager::new(cfg).expect("valid config");
        mgr.create_checkpoint(0).expect("cp");
        mgr.append_draft_kv(0).expect("draft");
        let result = mgr.accept_tokens(0, 1).expect("accept");
        assert_eq!(result.accepted_count, 1);
        assert_eq!(result.total_drafted, 1);
        assert!((result.acceptance_rate - 1.0).abs() < f32::EPSILON);
    }
}