use std::collections::VecDeque;
use crate::error::{DnnError, DnnResult};
/// Caller-supplied identifier carried by an [`InferenceRequest`] and used to
/// refer to that request in later calls and in scheduling decisions.
pub type RequestId = u64;
/// Relative importance of a request.
///
/// The derived `Ord` follows declaration order, so `Low < Normal < High`.
/// `SchedulingPolicy::PriorityBased` sorts the waiting queue with the highest
/// priority first.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Priority {
/// Lowest priority; compares below every other variant.
Low = 0,
/// Middle priority; also the value assigned to requests re-created by
/// `ContinuousBatcher::preempt`.
Normal = 1,
/// Highest priority; compares above every other variant.
High = 2,
}
/// A unit of work submitted to the batcher.
#[derive(Debug, Clone)]
pub struct InferenceRequest {
/// Identifier echoed back by `add_request` and listed in `BatchDecision`.
pub request_id: RequestId,
/// Number of tokens charged against the prefill budget when the request is
/// admitted. Must be non-zero and no larger than
/// `BatchConfig::max_sequence_length`.
pub sequence_length: usize,
/// Added to `sequence_length` to form the slot's `max_seq_len`; also the
/// sort key for `SchedulingPolicy::ShortestJobFirst`.
pub max_new_tokens: usize,
/// Used only by `SchedulingPolicy::PriorityBased`.
pub priority: Priority,
/// Sort key for FCFS/Orca and tie-breaker for the priority and deadline
/// policies; smaller values are scheduled first.
/// NOTE(review): presumably nanoseconds, going by the field name.
pub arrival_time_ns: u64,
/// Optional deadline used by `SchedulingPolicy::DeadlineAware`; `None` is
/// ordered as if it were `u64::MAX` (i.e. last).
pub deadline_ns: Option<u64>,
}
/// Per-request bookkeeping for a request that has been admitted into the batch.
#[derive(Debug, Clone)]
pub struct BatchSlot {
/// Value taken from the batcher's monotonically increasing slot counter.
pub slot_id: usize,
/// Request occupying this slot.
pub request_id: RequestId,
/// Tokens currently attributed to the request; summed over active slots to
/// produce `BatchDecision::total_tokens`.
pub current_seq_len: usize,
/// `sequence_length + max_new_tokens` of the admitted request. The visible
/// scheduler records this value but does not enforce it.
pub max_seq_len: usize,
/// Whether the slot is still in its prefill phase. Slots stored by
/// `ContinuousBatcher::step` always have this set to `false`.
pub is_prefill: bool,
/// Set to `true` when the slot is created; nothing in this file clears it.
pub is_active: bool,
}
/// Ordering applied to the waiting queue before each scheduling step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingPolicy {
/// Ascending `arrival_time_ns`.
Fcfs,
/// Ascending `max_new_tokens` (the prompt length is not part of the key).
ShortestJobFirst,
/// Descending `priority`, ties broken by ascending `arrival_time_ns`.
PriorityBased,
/// Ascending `deadline_ns` (missing deadlines last), ties broken by
/// ascending `arrival_time_ns`.
DeadlineAware,
/// Ordered exactly like `Fcfs` by the code in this file.
Orca,
}
/// Strategy for handling the state of a preempted request.
///
/// NOTE(review): nothing in the visible part of this file reads this enum; the
/// variant descriptions below are inferred from their names — confirm against
/// the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreemptionPolicy {
/// Presumably: discard the request's state and recompute it on resume.
Recompute,
/// Presumably: move the request's state elsewhere and restore it on resume.
Swap,
}
/// Static limits and policy for a `ContinuousBatcher`.
#[derive(Debug, Clone)]
pub struct BatchConfig {
/// Maximum number of resident slots; admission stops once it is reached.
pub max_batch_size: usize,
/// Token ceiling for a step: the prefill budget is capped at this value
/// minus the tokens attributed to the step's decode slots.
pub max_total_tokens: usize,
/// Largest `sequence_length` that `add_request` accepts.
pub max_sequence_length: usize,
/// Upper bound on the prompt tokens admitted in a single step (it is
/// compared against `sequence_length`, so it is a token count, not a
/// request count).
pub prefill_batch_size: usize,
/// Maximum number of decoding slots listed in `decode_requests` per step.
pub decode_batch_size: usize,
/// Ordering applied to the waiting queue before admission.
pub scheduling_policy: SchedulingPolicy,
}
/// Outcome of one `ContinuousBatcher::step` call.
#[derive(Debug, Clone)]
pub struct BatchDecision {
/// Requests admitted from the waiting queue during this step, in admission
/// order.
pub prefill_requests: Vec<RequestId>,
/// Already-resident requests selected for decoding in this step.
pub decode_requests: Vec<RequestId>,
/// Requests preempted by this step. `step` in this file never adds to this
/// list, so it is always empty.
pub preempted: Vec<RequestId>,
/// Sum of `current_seq_len` over all active slots after the step.
pub total_tokens: usize,
}
/// Mutable scheduling state owned by the batcher: resident slots plus the
/// queues of work that is not currently resident.
#[derive(Debug)]
struct BatchState {
    /// Slots of requests that have been admitted and not yet removed.
    active_slots: Vec<BatchSlot>,
    /// Cached sum of `current_seq_len` over the active slots.
    total_tokens: usize,
    /// Requests waiting to be admitted.
    prefill_queue: VecDeque<InferenceRequest>,
    /// Request ids awaiting decode. The code visible in this file only ever
    /// prunes this queue; nothing pushes to it.
    decode_queue: VecDeque<RequestId>,
    /// Requests that were evicted from their slot.
    preempted_queue: VecDeque<InferenceRequest>,
}

impl BatchState {
    /// Creates a state with no slots, empty queues and a zero token count.
    fn new() -> Self {
        let active_slots = Vec::new();
        let prefill_queue = VecDeque::new();
        let decode_queue = VecDeque::new();
        let preempted_queue = VecDeque::new();
        Self {
            active_slots,
            total_tokens: 0,
            prefill_queue,
            decode_queue,
            preempted_queue,
        }
    }
}
#[derive(Debug)]
pub struct ContinuousBatcher {
config: BatchConfig,
state: BatchState,
next_slot_id: usize,
completed_count: u64,
}
impl ContinuousBatcher {
pub fn new(config: BatchConfig) -> Self {
Self {
config,
state: BatchState::new(),
next_slot_id: 0,
completed_count: 0,
}
}
pub fn add_request(&mut self, request: InferenceRequest) -> DnnResult<RequestId> {
if request.sequence_length == 0 {
return Err(DnnError::InvalidArgument(
"sequence_length must be > 0".into(),
));
}
if request.sequence_length > self.config.max_sequence_length {
return Err(DnnError::InvalidArgument(format!(
"sequence_length {} exceeds max_sequence_length {}",
request.sequence_length, self.config.max_sequence_length
)));
}
let id = request.request_id;
self.state.prefill_queue.push_back(request);
Ok(id)
}
pub fn step(&mut self) -> DnnResult<BatchDecision> {
let mut decision = BatchDecision {
prefill_requests: Vec::new(),
decode_requests: Vec::new(),
preempted: Vec::new(),
total_tokens: 0,
};
let decode_ids: Vec<RequestId> = self
.state
.active_slots
.iter()
.filter(|s| s.is_active && !s.is_prefill)
.map(|s| s.request_id)
.collect();
let decode_count = decode_ids.len().min(self.config.decode_batch_size);
let decode_tokens: usize = self
.state
.active_slots
.iter()
.filter(|s| s.is_active && !s.is_prefill)
.take(decode_count)
.map(|s| s.current_seq_len + 1) .sum();
decision.decode_requests = decode_ids.into_iter().take(decode_count).collect();
self.sort_prefill_queue();
let mut prefill_budget = self
.config
.prefill_batch_size
.min(self.config.max_total_tokens.saturating_sub(decode_tokens));
let mut admitted = Vec::new();
while !self.state.prefill_queue.is_empty()
&& self.state.active_slots.len() + admitted.len() < self.config.max_batch_size
{
let req = match self.state.prefill_queue.front() {
Some(r) => r,
None => break,
};
if req.sequence_length > prefill_budget {
break;
}
let req = self
.state
.prefill_queue
.pop_front()
.ok_or_else(|| DnnError::InvalidArgument("empty queue".into()))?;
prefill_budget = prefill_budget.saturating_sub(req.sequence_length);
let slot = BatchSlot {
slot_id: self.next_slot_id,
request_id: req.request_id,
current_seq_len: req.sequence_length,
max_seq_len: req.sequence_length + req.max_new_tokens,
is_prefill: true,
is_active: true,
};
self.next_slot_id += 1;
decision.prefill_requests.push(req.request_id);
admitted.push(slot);
}
for slot in &mut admitted {
slot.is_prefill = false;
}
self.state.active_slots.extend(admitted);
for slot in &mut self.state.active_slots {
if slot.is_active && !slot.is_prefill {
slot.current_seq_len = slot.current_seq_len.saturating_add(1);
}
}
decision.total_tokens = self
.state
.active_slots
.iter()
.filter(|s| s.is_active)
.map(|s| s.current_seq_len)
.sum();
self.state.total_tokens = decision.total_tokens;
Ok(decision)
}
pub fn complete_request(&mut self, request_id: RequestId) -> DnnResult<()> {
let pos = self
.state
.active_slots
.iter()
.position(|s| s.request_id == request_id)
.ok_or_else(|| {
DnnError::InvalidArgument(format!("request {request_id} not in active slots"))
})?;
let slot = &self.state.active_slots[pos];
self.state.total_tokens = self.state.total_tokens.saturating_sub(slot.current_seq_len);
self.state.active_slots.remove(pos);
self.state.decode_queue.retain(|id| *id != request_id);
self.completed_count += 1;
Ok(())
}
pub fn preempt(&mut self, request_id: RequestId) -> DnnResult<()> {
let pos = self
.state
.active_slots
.iter()
.position(|s| s.request_id == request_id)
.ok_or_else(|| {
DnnError::InvalidArgument(format!("request {request_id} not in active slots"))
})?;
let slot = self.state.active_slots.remove(pos);
self.state.total_tokens = self.state.total_tokens.saturating_sub(slot.current_seq_len);
self.state.decode_queue.retain(|id| *id != request_id);
let preempted_req = InferenceRequest {
request_id,
sequence_length: slot.current_seq_len,
max_new_tokens: slot.max_seq_len.saturating_sub(slot.current_seq_len),
priority: Priority::Normal,
arrival_time_ns: 0,
deadline_ns: None,
};
self.state.preempted_queue.push_back(preempted_req);
Ok(())
}
pub fn active_requests(&self) -> usize {
self.state
.active_slots
.iter()
.filter(|s| s.is_active)
.count()
}
pub fn pending_requests(&self) -> usize {
self.state.prefill_queue.len() + self.state.preempted_queue.len()
}
pub fn throughput_tokens_per_step(&self) -> usize {
self.state.total_tokens
}
fn sort_prefill_queue(&mut self) {
let queue = &mut self.state.prefill_queue;
let policy = self.config.scheduling_policy;
let mut vec: Vec<InferenceRequest> = queue.drain(..).collect();
match policy {
SchedulingPolicy::Fcfs => {
vec.sort_by_key(|r| r.arrival_time_ns);
}
SchedulingPolicy::ShortestJobFirst => {
vec.sort_by_key(|r| r.max_new_tokens);
}
SchedulingPolicy::PriorityBased => {
vec.sort_by(|a, b| {
b.priority
.cmp(&a.priority)
.then(a.arrival_time_ns.cmp(&b.arrival_time_ns))
});
}
SchedulingPolicy::DeadlineAware => {
vec.sort_by(|a, b| {
let da = a.deadline_ns.unwrap_or(u64::MAX);
let db = b.deadline_ns.unwrap_or(u64::MAX);
da.cmp(&db).then(a.arrival_time_ns.cmp(&b.arrival_time_ns))
});
}
SchedulingPolicy::Orca => {
vec.sort_by_key(|r| r.arrival_time_ns);
}
}
*queue = VecDeque::from(vec);
}
}
/// Tracks how much of a fixed token budget is currently handed out.
#[derive(Debug)]
pub struct TokenBudgetAllocator {
    /// Size of the budget.
    max_total_tokens: usize,
    /// Tokens currently allocated; never exceeds `max_total_tokens`.
    allocated: usize,
}

impl TokenBudgetAllocator {
    /// Creates an allocator with nothing allocated.
    pub fn new(max_total_tokens: usize) -> Self {
        Self {
            max_total_tokens,
            allocated: 0,
        }
    }

    /// Reserves `seq_len` tokens as one all-or-nothing allocation.
    ///
    /// Returns the allocation count before the reservation (usable as a start
    /// offset), or `None` when the request does not fit. The bound is checked
    /// with overflow-safe arithmetic, so an arbitrarily large `seq_len` is
    /// rejected rather than panicking or wrapping around.
    pub fn allocate_prefill(&mut self, seq_len: usize) -> Option<usize> {
        let end = self
            .allocated
            .checked_add(seq_len)
            .filter(|&end| end <= self.max_total_tokens)?;
        Some(std::mem::replace(&mut self.allocated, end))
    }

    /// Reserves up to `count` tokens and returns how many were actually
    /// granted (possibly zero when the budget is exhausted).
    pub fn allocate_decode(&mut self, count: usize) -> usize {
        let remaining = self.max_total_tokens.saturating_sub(self.allocated);
        let granted = count.min(remaining);
        // Cannot overflow: `granted <= max_total_tokens - allocated`.
        self.allocated += granted;
        granted
    }

    /// Returns `tokens` to the budget; releasing more than is allocated
    /// clamps the allocation at zero.
    pub fn release(&mut self, tokens: usize) {
        self.allocated = self.allocated.saturating_sub(tokens);
    }

    /// Fraction of the budget in use, in `[0.0, 1.0]`; `0.0` for an empty
    /// budget.
    pub fn utilization(&self) -> f64 {
        if self.max_total_tokens == 0 {
            return 0.0;
        }
        self.allocated as f64 / self.max_total_tokens as f64
    }
}
/// Fixed pool of equally sized blocks with per-block reference counts.
#[derive(Debug)]
pub struct PagedKvManager {
    /// Number of blocks in the pool.
    num_blocks: usize,
    /// Tokens that fit in one block.
    block_size: usize,
    /// `free_map[i]` is `true` while block `i` is unallocated.
    free_map: Vec<bool>,
    /// `ref_counts[i]` is the number of owners of block `i` (0 when free).
    ref_counts: Vec<usize>,
}

impl PagedKvManager {
    /// Creates a pool of `num_blocks` free blocks holding `block_size` tokens
    /// each.
    pub fn new(num_blocks: usize, block_size: usize) -> Self {
        Self {
            num_blocks,
            block_size,
            free_map: vec![true; num_blocks],
            ref_counts: vec![0; num_blocks],
        }
    }

    /// Allocates enough blocks for `num_tokens` tokens, lowest ids first, and
    /// returns their ids with a reference count of one each.
    ///
    /// Asking for zero tokens succeeds with an empty list.
    ///
    /// # Errors
    ///
    /// `DnnError::InvalidArgument` when `block_size` is zero or fewer free
    /// blocks remain than are needed; the pool is left untouched.
    pub fn allocate(&mut self, num_tokens: usize) -> DnnResult<Vec<usize>> {
        if self.block_size == 0 {
            return Err(DnnError::InvalidArgument("block_size is 0".into()));
        }
        let blocks_needed = num_tokens.div_ceil(self.block_size);
        let free_blocks = self.free_block_count();
        if free_blocks < blocks_needed {
            return Err(DnnError::InvalidArgument(format!(
                "not enough free blocks: need {blocks_needed}, have {free_blocks}"
            )));
        }
        let ids: Vec<usize> = self
            .free_map
            .iter()
            .enumerate()
            .filter_map(|(id, &free)| free.then_some(id))
            .take(blocks_needed)
            .collect();
        for &id in &ids {
            self.free_map[id] = false;
            self.ref_counts[id] = 1;
        }
        Ok(ids)
    }

    /// Drops one reference from each listed block; a block whose count
    /// reaches zero returns to the free pool. Out-of-range ids are ignored.
    pub fn free(&mut self, block_ids: &[usize]) {
        for &id in block_ids {
            if id >= self.num_blocks {
                continue;
            }
            self.ref_counts[id] = self.ref_counts[id].saturating_sub(1);
            if self.ref_counts[id] == 0 {
                self.free_map[id] = true;
            }
        }
    }

    /// Gives the caller a block it owns exclusively in place of `block_id`.
    ///
    /// * If `block_id` has a single owner, no copy is needed and `block_id`
    ///   itself is returned; this also works when the pool is full.
    /// * If it is shared, a free block is claimed (reference count one), one
    ///   reference is dropped from `block_id`, and the new id is returned.
    ///
    /// Only the bookkeeping is done here; copying block contents is the
    /// caller's job.
    ///
    /// # Errors
    ///
    /// `DnnError::InvalidArgument` when `block_id` is out of range, refers to
    /// a block that is not allocated, or is shared while no free block is
    /// available. The pool is left untouched on error.
    pub fn copy_on_write(&mut self, block_id: usize) -> DnnResult<usize> {
        if block_id >= self.num_blocks {
            return Err(DnnError::InvalidArgument(format!(
                "block_id {block_id} out of range (max {})",
                self.num_blocks
            )));
        }
        // A free source must be rejected: otherwise the search below could
        // pick the source itself as the "copy" and hand back a block that is
        // immediately marked free again.
        if self.free_map[block_id] || self.ref_counts[block_id] == 0 {
            return Err(DnnError::InvalidArgument(format!(
                "block_id {block_id} is not allocated"
            )));
        }
        if self.ref_counts[block_id] == 1 {
            return Ok(block_id);
        }
        let new_id = self
            .free_map
            .iter()
            .position(|&free| free)
            .ok_or_else(|| {
                DnnError::InvalidArgument("no free blocks for copy-on-write".into())
            })?;
        self.free_map[new_id] = false;
        self.ref_counts[new_id] = 1;
        // Count is at least 2 here, so the source stays allocated.
        self.ref_counts[block_id] -= 1;
        Ok(new_id)
    }

    /// Returns `(blocks in use, total blocks)`.
    pub fn usage(&self) -> (usize, usize) {
        let used = self.free_map.iter().filter(|&&free| !free).count();
        (used, self.num_blocks)
    }

    /// Whether `allocate(num_tokens)` would currently succeed.
    pub fn can_allocate(&self, num_tokens: usize) -> bool {
        if self.block_size == 0 {
            return false;
        }
        self.free_block_count() >= num_tokens.div_ceil(self.block_size)
    }

    /// Number of blocks currently marked free.
    fn free_block_count(&self) -> usize {
        self.free_map.iter().filter(|&&free| free).count()
    }
}
/// Toy speculative-decoding helper: produces deterministic draft sequences
/// and keeps running acceptance statistics.
#[derive(Debug)]
pub struct SpeculativeDecoder {
    /// Number of tokens in every draft sequence.
    draft_length: usize,
    /// Draft tokens examined so far across all verification calls.
    total_proposed: u64,
    /// Draft tokens accepted so far across all verification calls.
    total_accepted: u64,
}

impl SpeculativeDecoder {
    /// Creates a decoder that drafts `draft_length` tokens per candidate, with
    /// zeroed statistics.
    pub fn new(draft_length: usize) -> Self {
        Self {
            draft_length,
            total_proposed: 0,
            total_accepted: 0,
        }
    }

    /// Builds `num_candidates` draft sequences of `draft_length` tokens each.
    ///
    /// Token `t` of candidate `c` is `(c * draft_length + t) % 32000`; the
    /// output depends only on the arguments and `draft_length`.
    pub fn propose_tokens(&self, num_candidates: usize) -> Vec<Vec<u32>> {
        let mut candidates = Vec::with_capacity(num_candidates);
        for candidate in 0..num_candidates {
            let mut draft = Vec::with_capacity(self.draft_length);
            for offset in 0..self.draft_length {
                draft.push(((candidate * self.draft_length + offset) % 32000) as u32);
            }
            candidates.push(draft);
        }
        candidates
    }

    /// Verifies the first candidate in `proposed` against `target_probs`.
    ///
    /// Tokens are accepted from the front for as long as the probability at
    /// the same index is at least `0.5`; a missing probability counts as
    /// `0.0`. Returns the accepted prefix and its length, and adds the
    /// candidate's length and the accepted count to the running totals.
    /// An empty `proposed` returns `(vec![], 0)` and leaves the totals alone.
    pub fn verify_and_accept(
        &mut self,
        proposed: &[Vec<u32>],
        target_probs: &[f64],
    ) -> (Vec<u32>, usize) {
        let Some(candidate) = proposed.first() else {
            return (Vec::new(), 0);
        };
        let accepted: Vec<u32> = candidate
            .iter()
            .enumerate()
            .take_while(|(idx, _)| target_probs.get(*idx).copied().unwrap_or(0.0) >= 0.5)
            .map(|(_, &token)| token)
            .collect();
        let accepted_len = accepted.len();
        self.total_proposed += candidate.len() as u64;
        self.total_accepted += accepted_len as u64;
        (accepted, accepted_len)
    }

    /// Accepted tokens divided by examined tokens over the decoder's
    /// lifetime; `0.0` before anything has been verified.
    pub fn acceptance_rate(&self) -> f64 {
        match self.total_proposed {
            0 => 0.0,
            proposed => self.total_accepted as f64 / proposed as f64,
        }
    }
}
/// Accumulates per-step measurements and time-to-first-token samples and
/// derives summary statistics from them.
///
/// All sums are accumulated in `u128`, so extreme sample values cannot
/// overflow the intermediate arithmetic.
#[derive(Debug)]
pub struct BatchMetrics {
    /// One `(prefill_tokens, decode_tokens, latency_us)` entry per step.
    steps: Vec<(usize, usize, u64)>,
    /// Recorded time-to-first-token samples, in microseconds.
    ttft_samples: Vec<u64>,
}

impl BatchMetrics {
    /// Creates an empty metrics collector.
    pub fn new() -> Self {
        Self {
            steps: Vec::new(),
            ttft_samples: Vec::new(),
        }
    }

    /// Records one step: its prefill token count, decode token count and
    /// latency in microseconds.
    pub fn record_step(&mut self, prefill_tokens: usize, decode_tokens: usize, latency_us: u64) {
        self.steps.push((prefill_tokens, decode_tokens, latency_us));
    }

    /// Records one time-to-first-token sample in microseconds.
    pub fn record_ttft(&mut self, ttft_us: u64) {
        self.ttft_samples.push(ttft_us);
    }

    /// Mean latency (us) of the steps that processed at least one prefill
    /// token; `0.0` when there are none.
    pub fn avg_prefill_latency(&self) -> f64 {
        Self::mean_latency(
            self.steps
                .iter()
                .filter(|(prefill, _, _)| *prefill > 0)
                .map(|(_, _, latency)| *latency),
        )
    }

    /// Mean latency (us) of the steps that processed at least one decode
    /// token; `0.0` when there are none. A mixed step counts towards both
    /// this and the prefill average.
    pub fn avg_decode_latency(&self) -> f64 {
        Self::mean_latency(
            self.steps
                .iter()
                .filter(|(_, decode, _)| *decode > 0)
                .map(|(_, _, latency)| *latency),
        )
    }

    /// Mean number of tokens (prefill + decode) per recorded step; `0.0`
    /// when no step has been recorded.
    pub fn avg_batch_size(&self) -> f64 {
        if self.steps.is_empty() {
            return 0.0;
        }
        self.total_tokens() as f64 / self.steps.len() as f64
    }

    /// Tokens per second over all recorded steps; `0.0` when nothing was
    /// recorded or the total latency is zero.
    pub fn token_throughput(&self) -> f64 {
        if self.steps.is_empty() {
            return 0.0;
        }
        let total_us: u128 = self
            .steps
            .iter()
            .map(|(_, _, latency)| u128::from(*latency))
            .sum();
        if total_us == 0 {
            return 0.0;
        }
        self.total_tokens() as f64 / (total_us as f64 / 1_000_000.0)
    }

    /// Median of the recorded time-to-first-token samples (us); the mean of
    /// the two middle samples for an even count, `0.0` with no samples.
    pub fn time_to_first_token_p50(&self) -> f64 {
        if self.ttft_samples.is_empty() {
            return 0.0;
        }
        let mut sorted = self.ttft_samples.clone();
        sorted.sort_unstable();
        let mid = sorted.len() / 2;
        if sorted.len() % 2 == 0 {
            // Non-empty and even, so `mid >= 1`. Widen before adding so two
            // large samples cannot overflow.
            (u128::from(sorted[mid - 1]) + u128::from(sorted[mid])) as f64 / 2.0
        } else {
            sorted[mid] as f64
        }
    }

    /// Renders all statistics as a multi-line, human-readable report.
    pub fn format_report(&self) -> String {
        format!(
            "BatchMetrics Report\n\
             ====================\n\
             Steps recorded : {}\n\
             Avg prefill latency : {:.1} us\n\
             Avg decode latency : {:.1} us\n\
             Avg batch size : {:.1} tokens/step\n\
             Token throughput : {:.0} tokens/s\n\
             TTFT p50 : {:.1} us\n\
             TTFT samples : {}",
            self.steps.len(),
            self.avg_prefill_latency(),
            self.avg_decode_latency(),
            self.avg_batch_size(),
            self.token_throughput(),
            self.time_to_first_token_p50(),
            self.ttft_samples.len(),
        )
    }

    /// Arithmetic mean of `latencies`, or `0.0` for an empty sequence.
    fn mean_latency(latencies: impl Iterator<Item = u64>) -> f64 {
        let (sum, count) = latencies.fold((0u128, 0usize), |(sum, count), latency| {
            (sum + u128::from(latency), count + 1)
        });
        if count == 0 {
            0.0
        } else {
            sum as f64 / count as f64
        }
    }

    /// Total prefill + decode tokens over all recorded steps.
    fn total_tokens(&self) -> u128 {
        self.steps
            .iter()
            .map(|(prefill, decode, _)| *prefill as u128 + *decode as u128)
            .sum()
    }
}

impl Default for BatchMetrics {
    /// Equivalent to [`BatchMetrics::new`].
    fn default() -> Self {
        Self::new()
    }
}
// Unit tests for the batcher, token budget allocator, paged KV manager,
// speculative decoder and metrics collector defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Baseline configuration shared by the batcher tests: 8 slots, FCFS ordering.
fn default_config() -> BatchConfig {
BatchConfig {
max_batch_size: 8,
max_total_tokens: 4096,
max_sequence_length: 2048,
prefill_batch_size: 1024,
decode_batch_size: 8,
scheduling_policy: SchedulingPolicy::Fcfs,
}
}
// Builds a normal-priority request without a deadline; the arrival time is
// derived from the id so that lower ids sort first under FCFS.
fn make_request(id: RequestId, seq_len: usize, max_new: usize) -> InferenceRequest {
InferenceRequest {
request_id: id,
sequence_length: seq_len,
max_new_tokens: max_new,
priority: Priority::Normal,
arrival_time_ns: id * 1000,
deadline_ns: None,
}
}
// A freshly added request is pending, not active.
#[test]
fn test_add_single_request() {
let mut batcher = ContinuousBatcher::new(default_config());
let req = make_request(1, 128, 64);
let id = batcher.add_request(req).expect("should succeed");
assert_eq!(id, 1);
assert_eq!(batcher.pending_requests(), 1);
assert_eq!(batcher.active_requests(), 0);
}
// The second step decodes the resident request while admitting a new one.
#[test]
fn test_batch_step_mixed_prefill_decode() {
let mut batcher = ContinuousBatcher::new(default_config());
batcher.add_request(make_request(1, 64, 32)).expect("add 1");
let d1 = batcher.step().expect("step 1");
assert_eq!(d1.prefill_requests.len(), 1);
batcher.add_request(make_request(2, 32, 16)).expect("add 2");
let d2 = batcher.step().expect("step 2");
assert!(!d2.decode_requests.is_empty(), "should have decode slots");
assert!(!d2.prefill_requests.is_empty(), "should have prefill slots");
}
// Prefill allocations are all-or-nothing; release lowers utilization.
#[test]
fn test_token_budget_allocation_release() {
let mut alloc = TokenBudgetAllocator::new(1024);
let slot = alloc.allocate_prefill(512);
assert!(slot.is_some());
assert!((alloc.utilization() - 0.5).abs() < 1e-9);
assert!(alloc.allocate_prefill(600).is_none());
alloc.release(256);
assert!((alloc.utilization() - 0.25).abs() < 1e-9);
}
// 128 tokens at 64 tokens per block need two blocks; freeing returns both.
#[test]
fn test_paged_kv_allocation_free() {
let mut mgr = PagedKvManager::new(16, 64);
let blocks = mgr.allocate(128).expect("allocate 128");
assert_eq!(blocks.len(), 2);
let (used, total) = mgr.usage();
assert_eq!(used, 2);
assert_eq!(total, 16);
mgr.free(&blocks);
let (used, _) = mgr.usage();
assert_eq!(used, 0);
}
// Copy-on-write of a shared block yields a new block and drops one reference
// from the original.
#[test]
fn test_copy_on_write() {
let mut mgr = PagedKvManager::new(4, 64);
let blocks = mgr.allocate(64).expect("allocate");
assert_eq!(blocks.len(), 1);
let orig = blocks[0];
// Simulate a second owner by writing the private field directly; this file
// exposes no public call that increments a reference count.
mgr.ref_counts[orig] = 2;
let new_id = mgr.copy_on_write(orig).expect("cow");
assert_ne!(new_id, orig);
assert!(!mgr.free_map[orig]);
assert_eq!(mgr.ref_counts[orig], 1);
assert_eq!(mgr.ref_counts[new_id], 1);
}
// Completing a request removes its slot.
#[test]
fn test_continuous_batching_completion() {
let mut batcher = ContinuousBatcher::new(default_config());
batcher.add_request(make_request(10, 64, 8)).expect("add");
let _ = batcher.step().expect("step");
assert_eq!(batcher.active_requests(), 1);
batcher.complete_request(10).expect("complete");
assert_eq!(batcher.active_requests(), 0);
}
// A preempted request leaves the active set and is counted as pending.
#[test]
fn test_preemption() {
let mut batcher = ContinuousBatcher::new(default_config());
batcher.add_request(make_request(20, 64, 16)).expect("add");
let _ = batcher.step().expect("step");
assert_eq!(batcher.active_requests(), 1);
batcher.preempt(20).expect("preempt");
assert_eq!(batcher.active_requests(), 0);
assert_eq!(batcher.pending_requests(), 1);
}
// FCFS admits by arrival time, not by insertion order.
#[test]
fn test_fcfs_scheduling_order() {
let mut batcher = ContinuousBatcher::new(default_config());
batcher.add_request(make_request(3, 32, 8)).expect("add 3");
batcher.add_request(make_request(1, 32, 8)).expect("add 1");
batcher.add_request(make_request(2, 32, 8)).expect("add 2");
let d = batcher.step().expect("step");
assert_eq!(d.prefill_requests, vec![1, 2, 3]);
}
// Priority scheduling admits High, then Normal, then Low regardless of
// arrival time.
#[test]
fn test_priority_based_scheduling() {
let mut config = default_config();
config.scheduling_policy = SchedulingPolicy::PriorityBased;
let mut batcher = ContinuousBatcher::new(config);
let mut low = make_request(1, 32, 8);
low.priority = Priority::Low;
low.arrival_time_ns = 100;
let mut high = make_request(2, 32, 8);
high.priority = Priority::High;
high.arrival_time_ns = 200;
let mut normal = make_request(3, 32, 8);
normal.priority = Priority::Normal;
normal.arrival_time_ns = 50;
batcher.add_request(low).expect("add low");
batcher.add_request(high).expect("add high");
batcher.add_request(normal).expect("add normal");
let d = batcher.step().expect("step");
assert_eq!(d.prefill_requests, vec![2, 3, 1]);
}
// Deadline-aware scheduling admits the earliest deadline first and requests
// without a deadline last.
#[test]
fn test_deadline_aware_scheduling() {
let mut config = default_config();
config.scheduling_policy = SchedulingPolicy::DeadlineAware;
let mut batcher = ContinuousBatcher::new(config);
let mut r1 = make_request(1, 32, 8);
r1.deadline_ns = Some(5000);
let mut r2 = make_request(2, 32, 8);
r2.deadline_ns = Some(1000);
let mut r3 = make_request(3, 32, 8);
r3.deadline_ns = None;
batcher.add_request(r1).expect("add r1");
batcher.add_request(r2).expect("add r2");
batcher.add_request(r3).expect("add r3");
let d = batcher.step().expect("step");
assert_eq!(d.prefill_requests, vec![2, 1, 3]);
}
// Acceptance stops at the first probability below the 0.5 threshold.
#[test]
fn test_speculative_decoding_acceptance() {
let mut spec = SpeculativeDecoder::new(4);
let proposed = spec.propose_tokens(2);
assert_eq!(proposed.len(), 2);
assert_eq!(proposed[0].len(), 4);
let probs = vec![0.8, 0.6, 0.3, 0.1];
let (accepted, count) = spec.verify_and_accept(&proposed, &probs);
assert_eq!(count, 2);
assert_eq!(accepted.len(), 2);
assert!(spec.acceptance_rate() > 0.0);
}
// The mixed step (64 prefill, 4 decode) counts towards both latency averages.
#[test]
fn test_batch_metrics_tracking() {
let mut m = BatchMetrics::new();
m.record_step(128, 0, 500);
m.record_step(0, 8, 100);
m.record_step(64, 4, 300);
assert!((m.avg_prefill_latency() - 400.0).abs() < 1e-9);
assert!((m.avg_decode_latency() - 200.0).abs() < 1e-9);
assert!((m.avg_batch_size() - 68.0).abs() < 1e-9);
assert!(m.token_throughput() > 0.0);
}
// No more than max_batch_size requests are admitted in one step.
#[test]
fn test_max_batch_size_enforcement() {
let mut config = default_config();
config.max_batch_size = 2;
let mut batcher = ContinuousBatcher::new(config);
for i in 0..4 {
batcher.add_request(make_request(i, 32, 8)).expect("add");
}
let d = batcher.step().expect("step");
assert!(d.prefill_requests.len() <= 2);
assert_eq!(batcher.active_requests(), d.prefill_requests.len());
}
// Requests move from pending to active on step and disappear on completion.
#[test]
fn test_queue_management() {
let mut batcher = ContinuousBatcher::new(default_config());
assert_eq!(batcher.pending_requests(), 0);
batcher.add_request(make_request(1, 32, 8)).expect("add");
batcher.add_request(make_request(2, 32, 8)).expect("add");
assert_eq!(batcher.pending_requests(), 2);
let _ = batcher.step().expect("step");
assert_eq!(batcher.pending_requests(), 0);
assert_eq!(batcher.active_requests(), 2);
batcher.complete_request(1).expect("complete");
assert_eq!(batcher.active_requests(), 1);
}
// Decode allocation is clamped to the remaining budget; an empty budget
// reports zero utilization.
#[test]
fn test_utilization_calculation() {
let mut alloc = TokenBudgetAllocator::new(1000);
assert!((alloc.utilization() - 0.0).abs() < 1e-9);
alloc.allocate_prefill(250);
assert!((alloc.utilization() - 0.25).abs() < 1e-9);
let fitted = alloc.allocate_decode(900);
assert_eq!(fitted, 750);
assert!((alloc.utilization() - 1.0).abs() < 1e-9);
let zero = TokenBudgetAllocator::new(0);
assert!((zero.utilization() - 0.0).abs() < 1e-9);
}
// The report contains the expected section labels.
#[test]
fn test_format_report() {
let mut m = BatchMetrics::new();
m.record_step(100, 10, 200);
m.record_step(0, 8, 100);
m.record_ttft(150);
m.record_ttft(250);
let report = m.format_report();
assert!(report.contains("Steps recorded"));
assert!(report.contains("Token throughput"));
assert!(report.contains("TTFT p50"));
}
}