use std::collections::VecDeque;
use crate::error::{DnnError, DnnResult};
/// Identifier used to refer to an inference request; an alias for `u64`.
pub type RequestId = u64;
/// Priority level attached to an [`InferenceRequest`].
///
/// The derived `Ord` follows the explicit discriminants, so
/// `Low < Normal < High`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Priority {
/// Lowest level (discriminant 0).
Low = 0,
/// Middle level (discriminant 1).
Normal = 1,
/// Highest level (discriminant 2).
High = 2,
}
/// A request waiting to be scheduled by [`ContinuousBatcher`].
#[derive(Debug, Clone)]
pub struct InferenceRequest {
/// Caller-chosen identifier; echoed back by `add_request` and used to look the
/// request up in `complete_request` / `preempt`.
pub request_id: RequestId,
/// Number of tokens the request brings with it. `add_request` requires this to be
/// non-zero and at most `BatchConfig::max_sequence_length`.
pub sequence_length: usize,
/// Added to `sequence_length` to form the slot's `max_seq_len` on admission; also the
/// sort key for `SchedulingPolicy::ShortestJobFirst`.
pub max_new_tokens: usize,
/// Used by `SchedulingPolicy::PriorityBased` (higher first).
pub priority: Priority,
/// Arrival timestamp; primary sort key for FCFS/Orca and tie-breaker for the
/// priority and deadline policies.
pub arrival_time_ns: u64,
/// Optional deadline used by `SchedulingPolicy::DeadlineAware`; `None` is treated as
/// `u64::MAX` when sorting.
pub deadline_ns: Option<u64>,
}
/// Per-request bookkeeping for a request that has been admitted by
/// `ContinuousBatcher::step`.
#[derive(Debug, Clone)]
pub struct BatchSlot {
/// Taken from the batcher's `next_slot_id` counter, which is incremented on every
/// admission.
pub slot_id: usize,
/// Id of the request occupying the slot.
pub request_id: RequestId,
/// Token count tracked for the sequence; starts from the request's
/// `sequence_length` and is advanced by `step`.
pub current_seq_len: usize,
/// `sequence_length + max_new_tokens` at admission. The visible code does not
/// compare `current_seq_len` against it.
pub max_seq_len: usize,
/// Decode selection in `step` skips slots where this is set; `step` stores newly
/// admitted slots with the flag cleared.
pub is_prefill: bool,
/// Only slots with this flag set are counted as active or scheduled.
pub is_active: bool,
}
/// Ordering applied to the waiting queue before each admission pass
/// (see `ContinuousBatcher::sort_prefill_queue`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingPolicy {
/// Ascending `arrival_time_ns`.
Fcfs,
/// Ascending `max_new_tokens`.
ShortestJobFirst,
/// Descending `priority`, ties broken by ascending `arrival_time_ns`.
PriorityBased,
/// Ascending `deadline_ns` (missing deadline sorts as `u64::MAX`), ties broken by
/// ascending `arrival_time_ns`.
DeadlineAware,
/// Currently ordered exactly like `Fcfs` (ascending `arrival_time_ns`).
Orca,
}
/// Strategy selector for preempted requests.
///
/// NOTE(review): nothing in the visible part of this file reads this type, so the
/// variant meanings below are inferred from their names only — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreemptionPolicy {
/// Presumably: discard the preempted request's state and recompute it on resume.
Recompute,
/// Presumably: move the preempted request's state elsewhere and restore it on resume.
Swap,
}
/// Limits and policy used by [`ContinuousBatcher`].
#[derive(Debug, Clone)]
pub struct BatchConfig {
/// Upper bound on the number of resident slots; `step` stops admitting once it is
/// reached.
pub max_batch_size: usize,
/// Token ceiling used when computing the per-step prefill budget
/// (`max_total_tokens - decode tokens`).
pub max_total_tokens: usize,
/// Longest `sequence_length` that `add_request` accepts.
pub max_sequence_length: usize,
/// Despite the name this is a *token* budget: `step` compares it against request
/// `sequence_length`s, not against a request count.
pub prefill_batch_size: usize,
/// Maximum number of slots listed in `BatchDecision::decode_requests` per step.
pub decode_batch_size: usize,
/// Ordering applied to the waiting queue before admission.
pub scheduling_policy: SchedulingPolicy,
}
/// Result of one `ContinuousBatcher::step` call.
#[derive(Debug, Clone)]
pub struct BatchDecision {
/// Ids admitted from the waiting queue during this step, in admission order.
pub prefill_requests: Vec<RequestId>,
/// Ids of resident slots scheduled for decoding during this step.
pub decode_requests: Vec<RequestId>,
/// Created empty by `step`; the visible code never adds entries to it.
pub preempted: Vec<RequestId>,
/// Sum of `current_seq_len` over all active slots after the step.
pub total_tokens: usize,
}
/// Mutable scheduling state owned by a `ContinuousBatcher`.
#[derive(Debug)]
struct BatchState {
    /// Slots of requests that were admitted and have not been completed or preempted.
    active_slots: Vec<BatchSlot>,
    /// Token total last computed by `step`, reduced when a slot is removed.
    total_tokens: usize,
    /// Requests waiting for admission.
    prefill_queue: VecDeque<InferenceRequest>,
    /// Pruned by `complete_request` / `preempt`; nothing in this file pushes to it.
    decode_queue: VecDeque<RequestId>,
    /// Requests evicted through `preempt`.
    preempted_queue: VecDeque<InferenceRequest>,
}

impl BatchState {
    /// Creates an empty state: no slots, no queued requests, zero tokens.
    fn new() -> Self {
        let active_slots = Vec::new();
        let prefill_queue = VecDeque::new();
        let decode_queue = VecDeque::new();
        let preempted_queue = VecDeque::new();
        Self {
            active_slots,
            total_tokens: 0,
            prefill_queue,
            decode_queue,
            preempted_queue,
        }
    }
}
/// Scheduler that admits queued requests into a bounded set of slots and advances
/// resident slots one step at a time.
#[derive(Debug)]
pub struct ContinuousBatcher {
/// Limits and ordering policy; fixed at construction.
config: BatchConfig,
/// Resident slots and waiting queues.
state: BatchState,
/// Id handed to the next admitted slot; incremented on every admission.
next_slot_id: usize,
/// Number of successful `complete_request` calls; not read elsewhere in this file.
completed_count: u64,
}
impl ContinuousBatcher {
/// Creates a batcher with the given configuration, no resident slots and empty queues.
pub fn new(config: BatchConfig) -> Self {
    let state = BatchState::new();
    Self {
        next_slot_id: 0,
        completed_count: 0,
        state,
        config,
    }
}
/// Validates `request` and appends it to the waiting (prefill) queue.
///
/// Returns the request's own `request_id` on success.
///
/// # Errors
///
/// Returns `DnnError::InvalidArgument` when `sequence_length` is
/// * zero,
/// * greater than `max_sequence_length`, or
/// * greater than `min(prefill_batch_size, max_total_tokens)`.
///
/// The last check exists because `step` never admits a prompt larger than that
/// per-step budget and stops admitting at the first request that does not fit;
/// such a request would otherwise stay queued forever and block the requests
/// sorted behind it.
pub fn add_request(&mut self, request: InferenceRequest) -> DnnResult<RequestId> {
    if request.sequence_length == 0 {
        return Err(DnnError::InvalidArgument(
            "sequence_length must be > 0".into(),
        ));
    }
    if request.sequence_length > self.config.max_sequence_length {
        return Err(DnnError::InvalidArgument(format!(
            "sequence_length {} exceeds max_sequence_length {}",
            request.sequence_length, self.config.max_sequence_length
        )));
    }
    // Largest prompt a single step can ever admit (the budget only shrinks from here).
    let prefill_capacity = self
        .config
        .prefill_batch_size
        .min(self.config.max_total_tokens);
    if request.sequence_length > prefill_capacity {
        return Err(DnnError::InvalidArgument(format!(
            "sequence_length {} exceeds per-step prefill capacity {}",
            request.sequence_length, prefill_capacity
        )));
    }
    let id = request.request_id;
    self.state.prefill_queue.push_back(request);
    Ok(id)
}
/// Runs one scheduling iteration and reports what was scheduled.
///
/// 1. **Decode** — up to `decode_batch_size` resident slots (in slot order) are
///    scheduled; each of them advances by one token. Slots that are not scheduled
///    in this step keep their current length.
/// 2. **Prefill** — the waiting queue is ordered by the configured policy, then
///    requests are admitted from the front while a batch slot is free and the
///    prompt fits the remaining budget
///    `min(prefill_batch_size, max_total_tokens - decode tokens)`. Admission stops
///    at the first request that does not fit. A newly admitted slot is stored with
///    `sequence_length + 1` tokens and with `is_prefill` cleared, so it is eligible
///    for decoding from the next step on.
///
/// `BatchDecision::total_tokens` (also cached for `throughput_tokens_per_step`) is
/// the sum of `current_seq_len` over all active slots after the step.
/// `BatchDecision::preempted` is always empty; requests moved to the preempted
/// queue by `preempt` are not re-admitted here. `current_seq_len` is not capped at
/// `max_seq_len`; callers retire sequences through `complete_request`.
///
/// # Errors
///
/// The current implementation has no failing path; the `DnnResult` return type is
/// kept so callers are unaffected.
pub fn step(&mut self) -> DnnResult<BatchDecision> {
    // --- Decode phase -----------------------------------------------------------
    let decode_limit = self.config.decode_batch_size;
    let mut decode_requests = Vec::new();
    let mut decode_tokens = 0usize;
    for slot in self
        .state
        .active_slots
        .iter_mut()
        .filter(|s| s.is_active && !s.is_prefill)
        .take(decode_limit)
    {
        // Only slots that are actually scheduled produce a token this step.
        slot.current_seq_len = slot.current_seq_len.saturating_add(1);
        decode_tokens = decode_tokens.saturating_add(slot.current_seq_len);
        decode_requests.push(slot.request_id);
    }

    // --- Prefill phase ----------------------------------------------------------
    self.sort_prefill_queue();
    let mut prefill_budget = self
        .config
        .prefill_batch_size
        .min(self.config.max_total_tokens.saturating_sub(decode_tokens));
    let mut prefill_requests = Vec::new();
    while self.state.active_slots.len() < self.config.max_batch_size {
        let fits = matches!(
            self.state.prefill_queue.front(),
            Some(req) if req.sequence_length <= prefill_budget
        );
        if !fits {
            // Empty queue, or head-of-line request does not fit: stop admitting.
            break;
        }
        let req = match self.state.prefill_queue.pop_front() {
            Some(req) => req,
            None => break,
        };
        prefill_budget = prefill_budget.saturating_sub(req.sequence_length);
        self.state.active_slots.push(BatchSlot {
            slot_id: self.next_slot_id,
            request_id: req.request_id,
            // Prefill counts the prompt plus the first generated token.
            current_seq_len: req.sequence_length.saturating_add(1),
            // Saturate instead of overflowing on extreme `max_new_tokens`.
            max_seq_len: req.sequence_length.saturating_add(req.max_new_tokens),
            is_prefill: false,
            is_active: true,
        });
        self.next_slot_id += 1;
        prefill_requests.push(req.request_id);
    }

    // --- Accounting -------------------------------------------------------------
    let total_tokens: usize = self
        .state
        .active_slots
        .iter()
        .filter(|s| s.is_active)
        .map(|s| s.current_seq_len)
        .sum();
    self.state.total_tokens = total_tokens;
    Ok(BatchDecision {
        prefill_requests,
        decode_requests,
        preempted: Vec::new(),
        total_tokens,
    })
}
/// Retires the resident slot that belongs to `request_id`.
///
/// The slot is removed, its `current_seq_len` is subtracted (saturating) from the
/// cached token total, the id is pruned from the decode queue and the completion
/// counter is incremented.
///
/// # Errors
///
/// Returns `DnnError::InvalidArgument` if no resident slot carries `request_id`.
pub fn complete_request(&mut self, request_id: RequestId) -> DnnResult<()> {
    let index = match self
        .state
        .active_slots
        .iter()
        .position(|slot| slot.request_id == request_id)
    {
        Some(index) => index,
        None => {
            return Err(DnnError::InvalidArgument(format!(
                "request {request_id} not in active slots"
            )));
        }
    };
    let finished = self.state.active_slots.remove(index);
    self.state.total_tokens = self
        .state
        .total_tokens
        .saturating_sub(finished.current_seq_len);
    self.state.decode_queue.retain(|queued| *queued != request_id);
    self.completed_count += 1;
    Ok(())
}
/// Evicts the resident slot of `request_id` and parks it in the preempted queue.
///
/// The parked request carries the slot's current length as its `sequence_length`
/// and the remaining headroom (`max_seq_len - current_seq_len`, saturating) as
/// `max_new_tokens`. The slot does not store the original priority, arrival time
/// or deadline, so the parked request gets `Priority::Normal`, arrival time `0`
/// and no deadline.
///
/// # Errors
///
/// Returns `DnnError::InvalidArgument` if no resident slot carries `request_id`.
pub fn preempt(&mut self, request_id: RequestId) -> DnnResult<()> {
    let index = match self
        .state
        .active_slots
        .iter()
        .position(|slot| slot.request_id == request_id)
    {
        Some(index) => index,
        None => {
            return Err(DnnError::InvalidArgument(format!(
                "request {request_id} not in active slots"
            )));
        }
    };
    let evicted = self.state.active_slots.remove(index);
    let held_tokens = evicted.current_seq_len;
    self.state.total_tokens = self.state.total_tokens.saturating_sub(held_tokens);
    self.state.decode_queue.retain(|queued| *queued != request_id);
    self.state.preempted_queue.push_back(InferenceRequest {
        request_id,
        sequence_length: held_tokens,
        max_new_tokens: evicted.max_seq_len.saturating_sub(held_tokens),
        priority: Priority::Normal,
        arrival_time_ns: 0,
        deadline_ns: None,
    });
    Ok(())
}
/// Number of resident slots whose `is_active` flag is set.
pub fn active_requests(&self) -> usize {
    self.state
        .active_slots
        .iter()
        .map(|slot| usize::from(slot.is_active))
        .sum()
}
/// Number of requests that are not resident: waiting for admission plus preempted.
pub fn pending_requests(&self) -> usize {
    let waiting = self.state.prefill_queue.len();
    let evicted = self.state.preempted_queue.len();
    waiting + evicted
}
pub fn throughput_tokens_per_step(&self) -> usize {
self.state.total_tokens
}
fn sort_prefill_queue(&mut self) {
let queue = &mut self.state.prefill_queue;
let policy = self.config.scheduling_policy;
let mut vec: Vec<InferenceRequest> = queue.drain(..).collect();
match policy {
SchedulingPolicy::Fcfs => {
vec.sort_by_key(|r| r.arrival_time_ns);
}
SchedulingPolicy::ShortestJobFirst => {
vec.sort_by_key(|r| r.max_new_tokens);
}
SchedulingPolicy::PriorityBased => {
vec.sort_by(|a, b| {
b.priority
.cmp(&a.priority)
.then(a.arrival_time_ns.cmp(&b.arrival_time_ns))
});
}
SchedulingPolicy::DeadlineAware => {
vec.sort_by(|a, b| {
let da = a.deadline_ns.unwrap_or(u64::MAX);
let db = b.deadline_ns.unwrap_or(u64::MAX);
da.cmp(&db).then(a.arrival_time_ns.cmp(&b.arrival_time_ns))
});
}
SchedulingPolicy::Orca => {
vec.sort_by_key(|r| r.arrival_time_ns);
}
}
*queue = VecDeque::from(vec);
}
}
/// Tracks how much of a fixed token budget is currently allocated.
#[derive(Debug)]
pub struct TokenBudgetAllocator {
    /// Total number of tokens that may be allocated at once.
    max_total_tokens: usize,
    /// Number of tokens currently allocated; never exceeds `max_total_tokens`.
    allocated: usize,
}

impl TokenBudgetAllocator {
    /// Creates an allocator with `max_total_tokens` capacity and nothing allocated.
    pub fn new(max_total_tokens: usize) -> Self {
        Self {
            max_total_tokens,
            allocated: 0,
        }
    }

    /// Reserves `seq_len` tokens in one piece.
    ///
    /// Returns the allocation count *before* the reservation (usable as an offset),
    /// or `None` — leaving the allocator untouched — when the request does not fit.
    /// A request so large that `allocated + seq_len` would overflow `usize` is
    /// treated as "does not fit" instead of panicking or wrapping around.
    pub fn allocate_prefill(&mut self, seq_len: usize) -> Option<usize> {
        let new_total = self
            .allocated
            .checked_add(seq_len)
            .filter(|&total| total <= self.max_total_tokens)?;
        let offset = self.allocated;
        self.allocated = new_total;
        Some(offset)
    }

    /// Reserves up to `count` tokens and returns how many were actually reserved
    /// (clipped to the remaining budget, possibly zero).
    pub fn allocate_decode(&mut self, count: usize) -> usize {
        let remaining = self.max_total_tokens.saturating_sub(self.allocated);
        let actual = count.min(remaining);
        // Cannot overflow: `actual <= remaining`.
        self.allocated += actual;
        actual
    }

    /// Returns `tokens` to the budget; releasing more than is allocated clamps at zero.
    pub fn release(&mut self, tokens: usize) {
        self.allocated = self.allocated.saturating_sub(tokens);
    }

    /// Fraction of the budget currently allocated; `0.0` for a zero-sized budget.
    pub fn utilization(&self) -> f64 {
        if self.max_total_tokens == 0 {
            return 0.0;
        }
        self.allocated as f64 / self.max_total_tokens as f64
    }
}
/// Bookkeeping for a pool of fixed-size blocks: which blocks are free and how many
/// references each allocated block has. No block contents are stored here.
#[derive(Debug)]
pub struct PagedKvManager {
/// Number of blocks in the pool; valid block ids are `0..num_blocks`.
num_blocks: usize,
/// Tokens per block, used to convert token counts into block counts.
block_size: usize,
/// `free_map[i]` is `true` while block `i` is unallocated.
free_map: Vec<bool>,
/// Reference count per block; set to 1 on allocation.
ref_counts: Vec<usize>,
}
impl PagedKvManager {
/// Creates a pool of `num_blocks` blocks of `block_size` tokens each, all free.
pub fn new(num_blocks: usize, block_size: usize) -> Self {
    let free_map = vec![true; num_blocks];
    let ref_counts = vec![0; num_blocks];
    Self {
        num_blocks,
        block_size,
        free_map,
        ref_counts,
    }
}
/// Allocates enough blocks to hold `num_tokens` tokens.
///
/// The lowest-numbered free blocks are claimed, each with a reference count of 1.
/// Returns the claimed block ids in ascending order (empty for `num_tokens == 0`).
///
/// # Errors
///
/// Returns `DnnError::InvalidArgument` when `block_size` is zero or when fewer free
/// blocks are available than required; nothing is allocated in either case.
pub fn allocate(&mut self, num_tokens: usize) -> DnnResult<Vec<usize>> {
    if self.block_size == 0 {
        return Err(DnnError::InvalidArgument("block_size is 0".into()));
    }
    let blocks_needed = num_tokens.div_ceil(self.block_size);
    if !self.can_allocate(num_tokens) {
        return Err(DnnError::InvalidArgument(format!(
            "not enough free blocks: need {blocks_needed}, have {}",
            self.free_block_count()
        )));
    }
    // Pick the ids first, then mark them, so the scan borrows the map read-only.
    let claimed: Vec<usize> = self
        .free_map
        .iter()
        .enumerate()
        .filter_map(|(id, &is_free)| is_free.then_some(id))
        .take(blocks_needed)
        .collect();
    for &id in &claimed {
        self.free_map[id] = false;
        self.ref_counts[id] = 1;
    }
    Ok(claimed)
}
/// Drops one reference from each listed block and marks a block free once its
/// count reaches zero. Ids outside the pool are ignored.
pub fn free(&mut self, block_ids: &[usize]) {
    let limit = self.num_blocks;
    for id in block_ids.iter().copied().filter(|&id| id < limit) {
        let remaining = self.ref_counts[id].saturating_sub(1);
        self.ref_counts[id] = remaining;
        if remaining == 0 {
            self.free_map[id] = true;
        }
    }
}
/// Gives the caller a private block in place of one reference to `block_id`.
///
/// The first free block is claimed with a reference count of 1, and one reference
/// is dropped from `block_id` (which becomes free if that was its last reference).
/// Only the bookkeeping is updated; this type holds no block contents to copy.
///
/// # Errors
///
/// Returns `DnnError::InvalidArgument` when
/// * `block_id` is outside the pool,
/// * `block_id` is currently free — there is no reference to split off, and
///   proceeding would silently consume a fresh block — or
/// * no free block is available.
///
/// The pool is left unchanged on error.
pub fn copy_on_write(&mut self, block_id: usize) -> DnnResult<usize> {
    if block_id >= self.num_blocks {
        return Err(DnnError::InvalidArgument(format!(
            "block_id {block_id} out of range (max {})",
            self.num_blocks
        )));
    }
    if self.free_map[block_id] {
        return Err(DnnError::InvalidArgument(format!(
            "block_id {block_id} is not allocated"
        )));
    }
    let new_id = self
        .free_map
        .iter()
        .position(|&free| free)
        .ok_or_else(|| DnnError::InvalidArgument("no free blocks for copy-on-write".into()))?;
    self.free_map[new_id] = false;
    self.ref_counts[new_id] = 1;
    // Release the caller's reference to the shared source block.
    self.ref_counts[block_id] = self.ref_counts[block_id].saturating_sub(1);
    if self.ref_counts[block_id] == 0 {
        self.free_map[block_id] = true;
    }
    Ok(new_id)
}
/// Returns `(allocated_blocks, total_blocks)`.
pub fn usage(&self) -> (usize, usize) {
    let free = self.free_map.iter().filter(|&&is_free| is_free).count();
    let used = self.free_map.len() - free;
    (used, self.num_blocks)
}
/// Reports whether enough free blocks exist to hold `num_tokens` tokens.
/// Always `false` when `block_size` is zero.
pub fn can_allocate(&self, num_tokens: usize) -> bool {
    // The zero check short-circuits before the division.
    self.block_size != 0 && self.free_block_count() >= num_tokens.div_ceil(self.block_size)
}
/// Counts the blocks currently marked free.
fn free_block_count(&self) -> usize {
    self.free_map.iter().map(|&is_free| usize::from(is_free)).sum()
}
}
/// Small deterministic 64-bit linear congruential generator.
///
/// The same seed always produces the same stream. Not suitable for cryptography.
#[derive(Debug, Clone)]
pub struct LcgRng {
    /// Current generator state; the most recent output of `next_u64`.
    state: u64,
}

impl LcgRng {
    /// LCG multiplier (the constant widely attributed to Knuth's MMIX generator).
    const MUL: u64 = 6_364_136_223_846_793_005;
    /// LCG increment (same attribution as `MUL`).
    const ADD: u64 = 1_442_695_040_888_963_407;

    /// Creates a generator from `seed`.
    ///
    /// The seed is scrambled with a multiply-add before use so that small,
    /// consecutive seeds start from well-separated states.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed
                .wrapping_mul(0x9E37_79B9_7F4A_7C15)
                .wrapping_add(Self::ADD),
        }
    }

    /// Advances the generator and returns the new 64-bit state.
    #[inline]
    pub fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_mul(Self::MUL).wrapping_add(Self::ADD);
        self.state
    }

    /// Returns a value in `[0, 1)` built from the top 53 bits of the next output.
    #[inline]
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Samples an index with probability proportional to its weight.
    ///
    /// Negative weights count as zero mass. The threshold is scaled by the sum of
    /// the *clamped* weights — the same quantity the scan below accumulates — so
    /// every positive-weight index is reachable with the correct probability even
    /// when negative entries are present.
    ///
    /// Returns `None`, without consuming a random draw, when `weights` is empty,
    /// contains a NaN or infinite entry, or has no positive mass.
    pub fn sample_categorical(&mut self, weights: &[f64]) -> Option<usize> {
        if weights.is_empty() || weights.iter().any(|w| !w.is_finite()) {
            return None;
        }
        let total: f64 = weights.iter().map(|w| w.max(0.0)).sum();
        // `total` can still overflow to infinity for huge finite weights.
        if !total.is_finite() || total <= 0.0 {
            return None;
        }
        let threshold = self.next_f64() * total;
        let mut acc = 0.0;
        for (idx, &w) in weights.iter().enumerate() {
            acc += w.max(0.0);
            if threshold < acc {
                return Some(idx);
            }
        }
        // Floating-point rounding can leave `threshold` at or above the final
        // accumulated sum; fall back to the last index that has positive mass.
        weights.iter().rposition(|&w| w > 0.0)
    }
}
/// Outcome of one `SpeculativeDecoder::verify_and_accept` round.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpeculativeResult {
/// Emitted token ids: the accepted drafts in order, followed by one extra token
/// (a correction after a rejection, or a bonus token when every draft was accepted).
pub tokens: Vec<u32>,
/// Number of drafted tokens accepted before the round ended.
pub accepted: usize,
/// `1` when the round ended on a rejection, `0` otherwise.
pub rejected: usize,
}
/// Draft-then-verify token sampler with running acceptance statistics.
#[derive(Debug)]
pub struct SpeculativeDecoder {
/// Maximum number of tokens `propose_tokens` drafts per call.
draft_length: usize,
/// Source of all randomness used for drafting and verification.
rng: LcgRng,
/// Sum of drafted-token counts over all recorded rounds.
total_proposed: u64,
/// Sum of accepted-token counts over all recorded rounds.
total_accepted: u64,
/// Number of recorded verification rounds.
rounds: u64,
}
impl SpeculativeDecoder {
/// Seed used by `new`; the bytes spell "SPEC" in ASCII.
const DEFAULT_SEED: u64 = 0x5350_4543;

/// Creates a decoder that drafts at most `draft_length` tokens per round, seeded
/// with `DEFAULT_SEED`.
#[must_use]
pub fn new(draft_length: usize) -> Self {
    let seed = Self::DEFAULT_SEED;
    Self::with_seed(draft_length, seed)
}
/// Creates a decoder with an explicit RNG seed and zeroed statistics.
#[must_use]
pub fn with_seed(draft_length: usize, seed: u64) -> Self {
    let rng = LcgRng::new(seed);
    Self {
        draft_length,
        rng,
        rounds: 0,
        total_accepted: 0,
        total_proposed: 0,
    }
}
#[must_use]
pub fn draft_length(&self) -> usize {
self.draft_length
}
/// Samples one draft token per row of `draft_probs`, using at most
/// `draft_length` rows.
///
/// Each returned `DraftedToken` carries the sampled index and that index's
/// probability after clamping negative entries to zero and normalising the row.
///
/// # Errors
///
/// Returns `DnnError::InvalidArgument` for the first row that cannot be sampled;
/// rows before it have already consumed random draws.
pub fn propose_tokens(&mut self, draft_probs: &[Vec<f64>]) -> DnnResult<Vec<DraftedToken>> {
    let limit = self.draft_length.min(draft_probs.len());
    let rng = &mut self.rng;
    draft_probs[..limit]
        .iter()
        .enumerate()
        .map(|(position, dist)| match rng.sample_categorical(dist) {
            Some(token) => {
                let mass: f64 = dist.iter().map(|p| p.max(0.0)).sum();
                Ok(DraftedToken {
                    token_id: token as u32,
                    draft_prob: dist[token].max(0.0) / mass,
                })
            }
            None => Err(DnnError::InvalidArgument(format!(
                "draft distribution at position {position} has no positive, finite mass"
            ))),
        })
        .collect()
}
/// Verifies `drafted` tokens against the target distributions by accept/reject
/// sampling and returns the tokens to emit for this round.
///
/// For drafted token `i`, `target_dists[i]` is normalised (negative entries clamped
/// to zero) and the token is accepted with probability
/// `min(1, p_target / draft_prob)`; one uniform draw is consumed per drafted token
/// examined. On the first rejection a correction token is sampled from
/// `residual_distribution`, the round is recorded, and the result has
/// `accepted = i`, `rejected = 1`. If every draft is accepted, a bonus token is
/// sampled from `target_dists[drafted.len()]` and the result has `rejected = 0`.
///
/// # Errors
///
/// Returns `DnnError::InvalidArgument` when
/// * `target_dists` has fewer than `drafted.len() + 1` rows,
/// * a drafted token id is not a valid index into its target row, or
/// * the residual or bonus distribution cannot be sampled.
///
/// When an error is returned part-way through, the statistics are not updated for
/// this round, but random draws made before the error have been consumed.
pub fn verify_and_accept(
&mut self,
drafted: &[DraftedToken],
target_dists: &[Vec<f64>],
) -> DnnResult<SpeculativeResult> {
let gamma = drafted.len();
// One row per drafted token plus the bonus row.
if target_dists.len() <= gamma {
return Err(DnnError::InvalidArgument(format!(
"target_dists must have at least {} rows (one per drafted token \
plus a bonus row), got {}",
gamma + 1,
target_dists.len(),
)));
}
let mut tokens = Vec::with_capacity(gamma + 1);
for (i, draft) in drafted.iter().enumerate() {
let token = draft.token_id as usize;
let target_dist = &target_dists[i];
let target_total: f64 = target_dist.iter().map(|p| p.max(0.0)).sum();
let p_target = target_dist
.get(token)
.copied()
.ok_or_else(|| {
DnnError::InvalidArgument(format!(
"target distribution at position {i} (len {}) does not \
contain drafted token id {token}",
target_dist.len(),
))
})?
.max(0.0);
// Normalise; a row without positive mass gives the drafted token probability 0.
let p_target = if target_total > 0.0 {
p_target / target_total
} else {
0.0
};
// A non-positive draft probability always rejects.
let accept_ratio = if draft.draft_prob > 0.0 {
(p_target / draft.draft_prob).min(1.0)
} else {
0.0
};
let r = self.rng.next_f64();
if r < accept_ratio {
tokens.push(draft.token_id);
continue;
}
// Rejected: replace this token with a sample from the residual and end the round.
let residual = Self::residual_distribution(target_dist, drafted, i);
let correction = self.rng.sample_categorical(&residual).ok_or_else(|| {
DnnError::InvalidArgument(format!(
"residual distribution at position {i} has no positive mass"
))
})?;
tokens.push(correction as u32);
let accepted = i;
self.record(gamma, accepted);
return Ok(SpeculativeResult {
tokens,
accepted,
rejected: 1,
});
}
// Every draft was accepted: emit one extra token from the row after the drafts.
let bonus_dist = &target_dists[gamma];
let bonus = self.rng.sample_categorical(bonus_dist).ok_or_else(|| {
DnnError::InvalidArgument(
"bonus target distribution has no positive, finite mass".into(),
)
})?;
tokens.push(bonus as u32);
self.record(gamma, gamma);
Ok(SpeculativeResult {
tokens,
accepted: gamma,
rejected: 0,
})
}
/// Builds the weights used to sample a correction token after the draft at
/// `position` was rejected.
///
/// The target row is clamped to non-negative values and normalised. The rejected
/// token's draft probability is subtracted from its own entry (floored at zero);
/// all other entries keep their normalised target probability, because only the
/// sampled token's draft probability is known here. If the result has no positive
/// mass, the clamped (un-normalised) target row is returned instead.
///
/// # Panics
///
/// Panics if `position` is out of bounds for `drafted`.
fn residual_distribution(
    target_dist: &[f64],
    drafted: &[DraftedToken],
    position: usize,
) -> Vec<f64> {
    let clamped: Vec<f64> = target_dist.iter().map(|p| p.max(0.0)).collect();
    let mass: f64 = clamped.iter().sum();
    let rejected = &drafted[position];
    let rejected_idx = rejected.token_id as usize;
    let rejected_prob = rejected.draft_prob.max(0.0);

    let residual: Vec<f64> = clamped
        .iter()
        .enumerate()
        .map(|(idx, &weight)| {
            let p_target = if mass > 0.0 { weight / mass } else { 0.0 };
            let p_draft = if idx == rejected_idx { rejected_prob } else { 0.0 };
            (p_target - p_draft).max(0.0)
        })
        .collect();

    let residual_mass: f64 = residual.iter().sum();
    if residual_mass <= 0.0 {
        clamped
    } else {
        residual
    }
}
/// Adds one finished round to the running statistics.
fn record(&mut self, proposed: usize, accepted: usize) {
    self.rounds += 1;
    self.total_accepted += accepted as u64;
    self.total_proposed += proposed as u64;
}
/// Accepted drafts divided by proposed drafts over all recorded rounds;
/// `0.0` before anything has been proposed.
#[must_use]
pub fn acceptance_rate(&self) -> f64 {
    match self.total_proposed {
        0 => 0.0,
        proposed => self.total_accepted as f64 / proposed as f64,
    }
}
#[must_use]
pub fn total_proposed(&self) -> u64 {
self.total_proposed
}
#[must_use]
pub fn total_accepted(&self) -> u64 {
self.total_accepted
}
#[must_use]
pub fn rounds(&self) -> u64 {
self.rounds
}
/// Average number of tokens emitted per round: accepted drafts plus the one extra
/// (correction or bonus) token each round produces. `0.0` before the first round.
#[must_use]
pub fn mean_tokens_per_round(&self) -> f64 {
    match self.rounds {
        0 => 0.0,
        rounds => (self.total_accepted + rounds) as f64 / rounds as f64,
    }
}
}
/// A token proposed by the draft step together with the probability the draft
/// distribution assigned to it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DraftedToken {
/// Index of the token in its distribution row.
pub token_id: u32,
/// Draft probability of `token_id`; `propose_tokens` fills in the normalised value.
/// Values `<= 0.0` make verification reject the token.
pub draft_prob: f64,
}
/// Accumulates per-step batching statistics and time-to-first-token samples.
#[derive(Debug)]
pub struct BatchMetrics {
    /// One `(prefill_tokens, decode_tokens, latency_us)` tuple per recorded step.
    steps: Vec<(usize, usize, u64)>,
    /// Recorded time-to-first-token samples, in microseconds.
    ttft_samples: Vec<u64>,
}

impl BatchMetrics {
    /// Creates an empty metrics collector.
    pub fn new() -> Self {
        Self {
            steps: Vec::new(),
            ttft_samples: Vec::new(),
        }
    }

    /// Records one step with its prefill/decode token counts and latency in microseconds.
    pub fn record_step(&mut self, prefill_tokens: usize, decode_tokens: usize, latency_us: u64) {
        self.steps.push((prefill_tokens, decode_tokens, latency_us));
    }

    /// Records one time-to-first-token sample in microseconds.
    pub fn record_ttft(&mut self, ttft_us: u64) {
        self.ttft_samples.push(ttft_us);
    }

    /// Mean latency (us) of the steps that processed at least one prefill token;
    /// `0.0` if there were none.
    pub fn avg_prefill_latency(&self) -> f64 {
        Self::mean_latency(
            self.steps
                .iter()
                .filter(|(prefill, _, _)| *prefill > 0)
                .map(|(_, _, latency)| *latency),
        )
    }

    /// Mean latency (us) of the steps that processed at least one decode token;
    /// `0.0` if there were none.
    pub fn avg_decode_latency(&self) -> f64 {
        Self::mean_latency(
            self.steps
                .iter()
                .filter(|(_, decode, _)| *decode > 0)
                .map(|(_, _, latency)| *latency),
        )
    }

    /// Mean number of tokens (prefill + decode) per recorded step; `0.0` without steps.
    pub fn avg_batch_size(&self) -> f64 {
        if self.steps.is_empty() {
            return 0.0;
        }
        let total: usize = self.steps.iter().map(|(p, d, _)| p + d).sum();
        total as f64 / self.steps.len() as f64
    }

    /// Tokens per second over all recorded steps; `0.0` without steps or when the
    /// summed latency is zero.
    pub fn token_throughput(&self) -> f64 {
        if self.steps.is_empty() {
            return 0.0;
        }
        let total_tokens: usize = self.steps.iter().map(|(p, d, _)| p + d).sum();
        let total_us: u64 = self.steps.iter().map(|(_, _, l)| l).sum();
        if total_us == 0 {
            return 0.0;
        }
        total_tokens as f64 / (total_us as f64 / 1_000_000.0)
    }

    /// Median time-to-first-token in microseconds; `0.0` without samples.
    ///
    /// For an even number of samples the two middle values are averaged. Each value
    /// is converted to `f64` before adding, so the sum cannot overflow `u64`.
    pub fn time_to_first_token_p50(&self) -> f64 {
        if self.ttft_samples.is_empty() {
            return 0.0;
        }
        let mut sorted = self.ttft_samples.clone();
        sorted.sort_unstable();
        let mid = sorted.len() / 2;
        if sorted.len() % 2 == 0 {
            // Non-empty and even implies len >= 2, so `mid - 1` is in range.
            (sorted[mid - 1] as f64 + sorted[mid] as f64) / 2.0
        } else {
            sorted[mid] as f64
        }
    }

    /// Renders all statistics as a multi-line, human-readable report.
    pub fn format_report(&self) -> String {
        format!(
            "BatchMetrics Report\n\
             ====================\n\
             Steps recorded : {}\n\
             Avg prefill latency : {:.1} us\n\
             Avg decode latency : {:.1} us\n\
             Avg batch size : {:.1} tokens/step\n\
             Token throughput : {:.0} tokens/s\n\
             TTFT p50 : {:.1} us\n\
             TTFT samples : {}",
            self.steps.len(),
            self.avg_prefill_latency(),
            self.avg_decode_latency(),
            self.avg_batch_size(),
            self.token_throughput(),
            self.time_to_first_token_p50(),
            self.ttft_samples.len(),
        )
    }

    /// Arithmetic mean of `latencies`, or `0.0` for an empty sequence.
    fn mean_latency(latencies: impl Iterator<Item = u64>) -> f64 {
        let (sum, count) = latencies.fold((0u64, 0usize), |(sum, count), latency| {
            (sum + latency, count + 1)
        });
        if count == 0 {
            0.0
        } else {
            sum as f64 / count as f64
        }
    }
}

impl Default for BatchMetrics {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Baseline configuration shared by the batcher tests: FCFS ordering, 8 slots,
// a 4096-token total budget and a 1024-token prefill budget.
fn default_config() -> BatchConfig {
BatchConfig {
max_batch_size: 8,
max_total_tokens: 4096,
max_sequence_length: 2048,
prefill_batch_size: 1024,
decode_batch_size: 8,
scheduling_policy: SchedulingPolicy::Fcfs,
}
}
// Builds a Normal-priority request without a deadline. The arrival time is derived
// from the id (id * 1000 ns), so lower ids sort earlier under arrival-time ordering.
fn make_request(id: RequestId, seq_len: usize, max_new: usize) -> InferenceRequest {
InferenceRequest {
request_id: id,
sequence_length: seq_len,
max_new_tokens: max_new,
priority: Priority::Normal,
arrival_time_ns: id * 1000,
deadline_ns: None,
}
}
// A freshly added request is pending, not active, and its id is echoed back.
#[test]
fn test_add_single_request() {
let mut batcher = ContinuousBatcher::new(default_config());
let req = make_request(1, 128, 64);
let id = batcher.add_request(req).expect("should succeed");
assert_eq!(id, 1);
assert_eq!(batcher.pending_requests(), 1);
assert_eq!(batcher.active_requests(), 0);
}
// Once request 1 is resident, a second step must both decode it and admit request 2.
#[test]
fn test_batch_step_mixed_prefill_decode() {
let mut batcher = ContinuousBatcher::new(default_config());
batcher.add_request(make_request(1, 64, 32)).expect("add 1");
let d1 = batcher.step().expect("step 1");
assert_eq!(d1.prefill_requests.len(), 1);
batcher.add_request(make_request(2, 32, 16)).expect("add 2");
let d2 = batcher.step().expect("step 2");
assert!(!d2.decode_requests.is_empty(), "should have decode slots");
assert!(!d2.prefill_requests.is_empty(), "should have prefill slots");
}
// 512 of 1024 tokens -> 50% utilization; a further 600 does not fit; releasing 256
// leaves 25%.
#[test]
fn test_token_budget_allocation_release() {
let mut alloc = TokenBudgetAllocator::new(1024);
let slot = alloc.allocate_prefill(512);
assert!(slot.is_some());
assert!((alloc.utilization() - 0.5).abs() < 1e-9);
assert!(alloc.allocate_prefill(600).is_none());
alloc.release(256);
assert!((alloc.utilization() - 0.25).abs() < 1e-9);
}
// 128 tokens with 64-token blocks need two blocks; freeing them returns usage to zero.
#[test]
fn test_paged_kv_allocation_free() {
let mut mgr = PagedKvManager::new(16, 64);
let blocks = mgr.allocate(128).expect("allocate 128");
assert_eq!(blocks.len(), 2);
let (used, total) = mgr.usage();
assert_eq!(used, 2);
assert_eq!(total, 16);
mgr.free(&blocks);
let (used, _) = mgr.usage();
assert_eq!(used, 0);
}
// Simulates a shared block by setting its ref count to 2 directly, then checks that
// copy-on-write hands out a different block and leaves the original allocated with
// one remaining reference.
#[test]
fn test_copy_on_write() {
let mut mgr = PagedKvManager::new(4, 64);
let blocks = mgr.allocate(64).expect("allocate");
assert_eq!(blocks.len(), 1);
let orig = blocks[0];
mgr.ref_counts[orig] = 2;
let new_id = mgr.copy_on_write(orig).expect("cow");
assert_ne!(new_id, orig);
assert!(!mgr.free_map[orig]);
assert_eq!(mgr.ref_counts[orig], 1);
assert_eq!(mgr.ref_counts[new_id], 1);
}
// Completing an admitted request removes it from the active set.
#[test]
fn test_continuous_batching_completion() {
let mut batcher = ContinuousBatcher::new(default_config());
batcher.add_request(make_request(10, 64, 8)).expect("add");
let _ = batcher.step().expect("step");
assert_eq!(batcher.active_requests(), 1);
batcher.complete_request(10).expect("complete");
assert_eq!(batcher.active_requests(), 0);
}
// A preempted request leaves the active set and is counted as pending again.
#[test]
fn test_preemption() {
let mut batcher = ContinuousBatcher::new(default_config());
batcher.add_request(make_request(20, 64, 16)).expect("add");
let _ = batcher.step().expect("step");
assert_eq!(batcher.active_requests(), 1);
batcher.preempt(20).expect("preempt");
assert_eq!(batcher.active_requests(), 0);
assert_eq!(batcher.pending_requests(), 1);
}
// Requests are queued out of order (3, 1, 2) but admitted by arrival time (1, 2, 3).
#[test]
fn test_fcfs_scheduling_order() {
let mut batcher = ContinuousBatcher::new(default_config());
batcher.add_request(make_request(3, 32, 8)).expect("add 3");
batcher.add_request(make_request(1, 32, 8)).expect("add 1");
batcher.add_request(make_request(2, 32, 8)).expect("add 2");
let d = batcher.step().expect("step");
assert_eq!(d.prefill_requests, vec![1, 2, 3]);
}
// Priority dominates arrival time: High (id 2) before Normal (id 3) before Low (id 1),
// even though the Low request arrived before the High one.
#[test]
fn test_priority_based_scheduling() {
let mut config = default_config();
config.scheduling_policy = SchedulingPolicy::PriorityBased;
let mut batcher = ContinuousBatcher::new(config);
let mut low = make_request(1, 32, 8);
low.priority = Priority::Low;
low.arrival_time_ns = 100;
let mut high = make_request(2, 32, 8);
high.priority = Priority::High;
high.arrival_time_ns = 200;
let mut normal = make_request(3, 32, 8);
normal.priority = Priority::Normal;
normal.arrival_time_ns = 50;
batcher.add_request(low).expect("add low");
batcher.add_request(high).expect("add high");
batcher.add_request(normal).expect("add normal");
let d = batcher.step().expect("step");
assert_eq!(d.prefill_requests, vec![2, 3, 1]);
}
// Earliest deadline first; the request without a deadline is admitted last.
#[test]
fn test_deadline_aware_scheduling() {
let mut config = default_config();
config.scheduling_policy = SchedulingPolicy::DeadlineAware;
let mut batcher = ContinuousBatcher::new(config);
let mut r1 = make_request(1, 32, 8);
r1.deadline_ns = Some(5000);
let mut r2 = make_request(2, 32, 8);
r2.deadline_ns = Some(1000);
let mut r3 = make_request(3, 32, 8);
r3.deadline_ns = None;
batcher.add_request(r1).expect("add r1");
batcher.add_request(r2).expect("add r2");
batcher.add_request(r3).expect("add r3");
let d = batcher.step().expect("step");
assert_eq!(d.prefill_requests, vec![2, 1, 3]);
}
// One-hot draft rows must yield exactly their hot index with draft_prob 1.0.
#[test]
fn test_speculative_decoding_propose_samples_draft() {
let mut spec = SpeculativeDecoder::with_seed(3, 12345);
let draft_probs = vec![
vec![0.0, 0.0, 1.0, 0.0],
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 0.0, 0.0, 1.0],
];
let drafted = spec.propose_tokens(&draft_probs).expect("propose");
assert_eq!(drafted.len(), 3);
assert_eq!(drafted[0].token_id, 2);
assert_eq!(drafted[1].token_id, 0);
assert_eq!(drafted[2].token_id, 3);
for d in &drafted {
assert!((d.draft_prob - 1.0).abs() < 1e-12);
}
}
// draft_length = 2 limits four input rows to two drafts; the un-normalised weights
// [1, 3] and [3, 1] must be reported as probabilities 0.25 or 0.75.
#[test]
fn test_speculative_decoding_propose_caps_and_normalises() {
let mut spec = SpeculativeDecoder::with_seed(2, 99);
let draft_probs = vec![
vec![1.0, 3.0],
vec![3.0, 1.0],
vec![1.0, 0.0],
vec![0.0, 1.0],
];
let drafted = spec.propose_tokens(&draft_probs).expect("propose");
assert_eq!(drafted.len(), 2, "draft_length caps the count");
for d in &drafted {
assert!((0.0..=1.0).contains(&d.draft_prob));
let p = d.draft_prob;
assert!(
(p - 0.25).abs() < 1e-12 || (p - 0.75).abs() < 1e-12,
"unexpected normalised prob {p}"
);
}
}
// An all-zero row cannot be sampled, so proposing must fail.
#[test]
fn test_speculative_decoding_propose_rejects_zero_dist() {
let mut spec = SpeculativeDecoder::new(2);
let draft_probs = vec![vec![0.0, 0.0, 0.0]];
assert!(spec.propose_tokens(&draft_probs).is_err());
}
// Empirical frequencies over 200k draws must be within 0.01 of the weights.
#[test]
fn test_categorical_sampling_matches_distribution() {
let mut rng = LcgRng::new(0x00C0_FFEE);
let weights = [0.1_f64, 0.2, 0.3, 0.4];
let trials = 200_000;
let mut counts = [0u64; 4];
for _ in 0..trials {
let idx = rng.sample_categorical(&weights).expect("sample");
counts[idx] += 1;
}
for (i, &w) in weights.iter().enumerate() {
let freq = counts[i] as f64 / trials as f64;
assert!(
(freq - w).abs() < 0.01,
"category {i}: freq {freq} vs expected {w}"
);
}
}
// Accept ratio is target / draft = 0.4 / 0.8 = 0.5, so across many seeds the single
// drafted token should be accepted about half the time.
#[test]
fn test_rejection_sampling_acceptance_probability() {
let trials = 100_000;
let mut accepted_rounds = 0u64;
for seed in 0..trials {
let mut spec = SpeculativeDecoder::with_seed(1, seed);
let drafted = vec![DraftedToken {
token_id: 0,
draft_prob: 0.8,
}];
let target = vec![vec![0.4, 0.6], vec![0.5, 0.5]];
let res = spec.verify_and_accept(&drafted, &target).expect("verify");
if res.accepted == 1 {
accepted_rounds += 1;
}
}
let rate = accepted_rounds as f64 / trials as f64;
assert!(
(rate - 0.5).abs() < 0.01,
"acceptance rate {rate} should be ~0.5"
);
}
// Target probability (0.6) exceeds the draft probability (0.3), so the ratio is
// clamped to 1 and acceptance is certain for every seed.
#[test]
fn test_rejection_sampling_always_accepts_when_target_ge_draft() {
for seed in 0..2000 {
let mut spec = SpeculativeDecoder::with_seed(1, seed);
let drafted = vec![DraftedToken {
token_id: 0,
draft_prob: 0.3,
}];
let target = vec![vec![0.6, 0.4], vec![0.5, 0.5]];
let res = spec.verify_and_accept(&drafted, &target).expect("verify");
assert_eq!(res.accepted, 1, "ratio >= 1 must always accept");
assert_eq!(res.rejected, 0);
}
}
// draft_prob == 0 forces the accept ratio to 0: the draft is rejected and only a
// correction token is emitted.
#[test]
fn test_rejection_sampling_rejects_zero_draft_prob() {
let mut spec = SpeculativeDecoder::with_seed(1, 7);
let drafted = vec![DraftedToken {
token_id: 0,
draft_prob: 0.0,
}];
let target = vec![vec![0.9, 0.1], vec![0.5, 0.5]];
let res = spec.verify_and_accept(&drafted, &target).expect("verify");
assert_eq!(res.accepted, 0);
assert_eq!(res.rejected, 1);
assert_eq!(res.tokens.len(), 1);
}
// The target gives token 0 zero mass, so the draft is always rejected; corrections
// must split evenly between tokens 1 and 2 and never pick token 0.
#[test]
fn test_residual_distribution_resampling() {
let trials = 100_000;
let mut counts = [0u64; 3];
for seed in 0..trials {
let mut spec = SpeculativeDecoder::with_seed(1, seed);
let drafted = vec![DraftedToken {
token_id: 0,
draft_prob: 1.0,
}];
let target = vec![vec![0.0, 0.5, 0.5], vec![1.0, 0.0, 0.0]];
let res = spec.verify_and_accept(&drafted, &target).expect("verify");
assert_eq!(res.accepted, 0, "must reject");
let corr = res.tokens[0] as usize;
counts[corr] += 1;
}
let total = trials as f64;
assert_eq!(counts[0], 0, "token 0 has zero residual mass");
assert!((counts[1] as f64 / total - 0.5).abs() < 0.01);
assert!((counts[2] as f64 / total - 0.5).abs() < 0.01);
}
// Drafted token 1 has zero target mass; the only residual mass is on token 0, so the
// correction is always token 0.
#[test]
fn test_residual_distribution_concentrated() {
for seed in 0..1000 {
let mut spec = SpeculativeDecoder::with_seed(1, seed);
let drafted = vec![DraftedToken {
token_id: 1,
draft_prob: 1.0,
}];
let target = vec![vec![1.0, 0.0], vec![0.5, 0.5]];
let res = spec.verify_and_accept(&drafted, &target).expect("verify");
assert_eq!(res.accepted, 0);
assert_eq!(res.tokens[0], 0, "residual concentrates on token 0");
}
}
// Draft and target agree on the same one-hot token, so the residual is all zero and
// the clamped target row is returned instead.
#[test]
fn test_residual_distribution_zero_fallback() {
let drafted = [DraftedToken {
token_id: 0,
draft_prob: 1.0,
}];
let target_dist = [1.0_f64, 0.0, 0.0];
let residual = SpeculativeDecoder::residual_distribution(&target_dist, &drafted, 0);
assert_eq!(residual, vec![1.0, 0.0, 0.0]);
}
// Drafting from the same distribution that is used as the target gives an accept
// ratio of 1 at every position: all gamma drafts are accepted and a bonus token is
// appended.
#[test]
fn test_speculative_draft_equals_target_accepts_all() {
let gamma = 5;
for seed in 0..3000 {
let mut spec = SpeculativeDecoder::with_seed(gamma, seed);
let dist = vec![0.15, 0.25, 0.20, 0.40];
let draft_probs = vec![dist.clone(); gamma];
let drafted = spec.propose_tokens(&draft_probs).expect("propose");
assert_eq!(drafted.len(), gamma);
let target_dists = vec![dist.clone(); gamma + 1];
let res = spec
.verify_and_accept(&drafted, &target_dists)
.expect("verify");
assert_eq!(res.accepted, gamma, "draft==target must accept all");
assert_eq!(res.rejected, 0);
assert_eq!(res.tokens.len(), gamma + 1);
}
}
// Each draft is accepted with probability 0.7 (target 0.7 / draft 1.0), and a round
// stops at the first rejection. Expected accepted drafts per round:
// 0.7 + 0.7^2 + 0.7^3 + 0.7^4 = 1.7731, i.e. a rate of about 0.4433 with gamma = 4.
// Also cross-checks the decoder's running counters against the per-round results.
#[test]
fn test_speculative_accepted_length_distribution() {
let gamma = 4usize;
let mut spec = SpeculativeDecoder::with_seed(gamma, 0xABCD);
let rounds = 5000u64;
let mut sum_accepted = 0u64;
for _ in 0..rounds {
let draft_probs = vec![vec![1.0, 0.0]; gamma];
let drafted = spec.propose_tokens(&draft_probs).expect("propose");
let target_dists = vec![vec![0.7, 0.3]; gamma + 1];
let res = spec
.verify_and_accept(&drafted, &target_dists)
.expect("verify");
assert!(res.accepted <= gamma, "accepted within [0, gamma]");
assert_eq!(res.rejected, usize::from(res.accepted < gamma));
assert_eq!(res.tokens.len(), res.accepted + 1);
sum_accepted += res.accepted as u64;
}
assert_eq!(spec.total_proposed(), rounds * gamma as u64);
assert_eq!(spec.total_accepted(), sum_accepted);
assert_eq!(spec.rounds(), rounds);
let rate = spec.acceptance_rate();
assert!(
(rate - 0.4433).abs() < 0.02,
"acceptance rate {rate} should be ~0.4433"
);
let mtpr = spec.mean_tokens_per_round();
assert!(mtpr >= 1.0 && mtpr <= (gamma + 1) as f64, "mtpr {mtpr}");
}
// Two drafts need three target rows (the last one is the bonus row); supplying only
// two must be rejected.
#[test]
fn test_speculative_verify_rejects_short_target() {
let mut spec = SpeculativeDecoder::new(2);
let drafted = vec![
DraftedToken {
token_id: 0,
draft_prob: 0.5,
},
DraftedToken {
token_id: 1,
draft_prob: 0.5,
},
];
let target = vec![vec![0.5, 0.5], vec![0.5, 0.5]];
assert!(spec.verify_and_accept(&drafted, &target).is_err());
}
// Token id 9 does not exist in a two-entry target row, so verification must fail.
#[test]
fn test_speculative_verify_rejects_token_out_of_range() {
let mut spec = SpeculativeDecoder::new(1);
let drafted = vec![DraftedToken {
token_id: 9, draft_prob: 0.5,
}];
let target = vec![vec![0.5, 0.5], vec![0.5, 0.5]];
assert!(spec.verify_and_accept(&drafted, &target).is_err());
}
// With nothing drafted, the single target row is the bonus row and exactly one token
// is emitted from it.
#[test]
fn test_speculative_empty_draft_emits_bonus() {
let mut spec = SpeculativeDecoder::with_seed(4, 55);
let drafted: Vec<DraftedToken> = Vec::new();
let target = vec![vec![0.0, 1.0, 0.0]];
let res = spec.verify_and_accept(&drafted, &target).expect("verify");
assert_eq!(res.accepted, 0);
assert_eq!(res.rejected, 0);
assert_eq!(res.tokens, vec![1], "bonus drawn from one-hot target");
}
// Two generators with the same seed produce identical streams; values stay in
// [0, 1) and their mean is close to 0.5.
#[test]
fn test_lcg_rng_uniform_and_deterministic() {
let mut a = LcgRng::new(2024);
let mut b = LcgRng::new(2024);
let mut sum = 0.0_f64;
let n = 100_000;
for _ in 0..n {
let va = a.next_f64();
let vb = b.next_f64();
assert_eq!(va, vb, "same seed must yield same stream");
assert!((0.0..1.0).contains(&va));
sum += va;
}
let mean = sum / n as f64;
assert!((mean - 0.5).abs() < 0.01, "uniform mean {mean}");
}
// Prefill steps took 500 and 300 us -> mean 400; decode steps took 100 and 300 us ->
// mean 200; tokens per step: (128 + 8 + 68) / 3 = 68.
#[test]
fn test_batch_metrics_tracking() {
let mut m = BatchMetrics::new();
m.record_step(128, 0, 500);
m.record_step(0, 8, 100);
m.record_step(64, 4, 300);
assert!((m.avg_prefill_latency() - 400.0).abs() < 1e-9);
assert!((m.avg_decode_latency() - 200.0).abs() < 1e-9);
assert!((m.avg_batch_size() - 68.0).abs() < 1e-9);
assert!(m.token_throughput() > 0.0);
}
// With max_batch_size = 2, at most two of the four queued requests may be admitted
// in one step, and the active count must match what was admitted.
#[test]
fn test_max_batch_size_enforcement() {
let mut config = default_config();
config.max_batch_size = 2;
let mut batcher = ContinuousBatcher::new(config);
for i in 0..4 {
batcher.add_request(make_request(i, 32, 8)).expect("add");
}
let d = batcher.step().expect("step");
assert!(d.prefill_requests.len() <= 2);
assert_eq!(batcher.active_requests(), d.prefill_requests.len());
}
// Walks two requests through pending -> active -> completed and checks the counters
// at each stage.
#[test]
fn test_queue_management() {
let mut batcher = ContinuousBatcher::new(default_config());
assert_eq!(batcher.pending_requests(), 0);
batcher.add_request(make_request(1, 32, 8)).expect("add");
batcher.add_request(make_request(2, 32, 8)).expect("add");
assert_eq!(batcher.pending_requests(), 2);
let _ = batcher.step().expect("step");
assert_eq!(batcher.pending_requests(), 0);
assert_eq!(batcher.active_requests(), 2);
batcher.complete_request(1).expect("complete");
assert_eq!(batcher.active_requests(), 1);
}
// 250 of 1000 -> 25%; a decode request for 900 is clipped to the 750 remaining,
// giving 100%; a zero-sized budget reports 0 utilization instead of dividing by zero.
#[test]
fn test_utilization_calculation() {
let mut alloc = TokenBudgetAllocator::new(1000);
assert!((alloc.utilization() - 0.0).abs() < 1e-9);
alloc.allocate_prefill(250);
assert!((alloc.utilization() - 0.25).abs() < 1e-9);
let fitted = alloc.allocate_decode(900);
assert_eq!(fitted, 750);
assert!((alloc.utilization() - 1.0).abs() < 1e-9);
let zero = TokenBudgetAllocator::new(0);
assert!((zero.utilization() - 0.0).abs() < 1e-9);
}
// Smoke test: the rendered report contains the expected section labels.
#[test]
fn test_format_report() {
let mut m = BatchMetrics::new();
m.record_step(100, 10, 200);
m.record_step(0, 8, 100);
m.record_ttft(150);
m.record_ttft(250);
let report = m.format_report();
assert!(report.contains("Steps recorded"));
assert!(report.contains("Token throughput"));
assert!(report.contains("TTFT p50"));
}
}