#![allow(clippy::must_use_candidate)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::unnecessary_wraps)] #![allow(clippy::derivable_impls)] #![allow(clippy::option_if_let_else)]
use crate::paged_kv::{PagedCacheError, PagedKvCache, SeqId};
use serde::{Deserialize, Serialize};
use std::collections::{BinaryHeap, HashMap, VecDeque};
use std::time::Instant;
use thiserror::Error;
/// Errors surfaced by the request schedulers in this module.
#[derive(Debug, Error)]
pub enum SchedulerError {
    /// The waiting queue already holds `capacity` requests.
    #[error("Request queue full: capacity {capacity}")]
    QueueFull {
        capacity: usize,
    },
    /// No request with the given id is tracked by the scheduler.
    #[error("Request not found: {0}")]
    RequestNotFound(u64),
    /// Propagated failure from the paged KV cache.
    #[error("KV cache error: {0}")]
    CacheError(#[from] PagedCacheError),
    /// Catch-all for internal invariant violations.
    #[error("Invalid scheduler state: {0}")]
    InvalidState(String),
}
/// Request priority class. Discriminants are ordered so the derived `Ord`
/// ranks `Low < Normal < High < Critical`; the numeric value also doubles
/// as an index into per-priority queues and budgets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum Priority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
impl Default for Priority {
    /// New requests default to `Normal` priority.
    fn default() -> Self {
        Self::Normal
    }
}
/// Lifecycle state of a request/sequence within a scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SequenceState {
    Waiting,
    Running,
    Preempted,
    Completed,
    Failed,
}
/// A generation request tracked by [`Scheduler`].
#[derive(Debug, Clone)]
pub struct SchedulerRequest {
    pub request_id: u64,
    /// Prompt token ids.
    pub input_ids: Vec<u32>,
    /// Maximum number of tokens to generate.
    pub max_tokens: usize,
    pub priority: Priority,
    /// Timestamp used for FIFO tie-breaking and wait-time stats.
    pub arrival_time: Instant,
    /// KV-cache sequence handle; `None` while waiting or preempted.
    pub seq_id: Option<SeqId>,
    pub state: SequenceState,
    /// Tokens produced so far.
    pub generated_tokens: Vec<u32>,
    /// Number of model iterations run for this request (0 => prefill pending).
    pub iterations: usize,
}
impl SchedulerRequest {
pub fn new(request_id: u64, input_ids: Vec<u32>, max_tokens: usize) -> Self {
Self {
request_id,
input_ids,
max_tokens,
priority: Priority::default(),
arrival_time: Instant::now(),
seq_id: None,
state: SequenceState::Waiting,
generated_tokens: Vec::new(),
iterations: 0,
}
}
pub fn with_priority(mut self, priority: Priority) -> Self {
self.priority = priority;
self
}
pub fn total_tokens(&self) -> usize {
self.input_ids.len() + self.generated_tokens.len()
}
pub fn remaining_tokens(&self) -> usize {
self.max_tokens.saturating_sub(self.generated_tokens.len())
}
pub fn is_complete(&self, eos_token: u32) -> bool {
self.generated_tokens.len() >= self.max_tokens
|| self.generated_tokens.last() == Some(&eos_token)
}
pub fn wait_time(&self) -> std::time::Duration {
self.arrival_time.elapsed()
}
}
/// Max-heap entry for the waiting queue; carries only the fields the
/// ordering needs.
#[derive(Debug)]
struct PriorityEntry {
    priority: Priority,
    arrival_time: Instant,
    request_id: u64,
}
/// Equality is identity by request id.
/// NOTE(review): `Ord` below compares priority/arrival instead, so two
/// distinct entries can compare `Equal` while `eq` returns false. That is
/// harmless for `BinaryHeap`, but technically inconsistent with the `Ord`
/// contract — confirm no ordered containers rely on it.
impl PartialEq for PriorityEntry {
    fn eq(&self, other: &Self) -> bool {
        self.request_id == other.request_id
    }
}
impl Eq for PriorityEntry {}
impl PartialOrd for PriorityEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Higher priority ranks greater; ties break FIFO — the comparison of
/// arrival times is reversed so the *earlier* arrival compares greater and
/// is popped first by the max-heap.
impl Ord for PriorityEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match self.priority.cmp(&other.priority) {
            // Reversed on purpose: earlier arrival_time => Greater.
            std::cmp::Ordering::Equal => other.arrival_time.cmp(&self.arrival_time),
            other => other,
        }
    }
}
/// Result of one scheduling iteration: which sequences to run this step and
/// which requests finished or were preempted.
#[derive(Debug, Clone, Default)]
pub struct SchedulerOutput {
    /// KV-cache sequence ids to include in this step's batch.
    pub scheduled_seq_ids: Vec<SeqId>,
    /// Request ids parallel to `scheduled_seq_ids`.
    pub scheduled_request_ids: Vec<u64>,
    /// Sequences evicted from the KV cache this iteration.
    pub preempted_seq_ids: Vec<SeqId>,
    /// Requests detected as finished; the caller is expected to finalize
    /// them via `Scheduler::complete_request`.
    pub completed_request_ids: Vec<u64>,
    /// Prompt tokens to process this step.
    pub num_prefill_tokens: usize,
    /// Decode tokens (one per decoding sequence) this step.
    pub num_decode_tokens: usize,
}
impl SchedulerOutput {
    /// Total tokens (prefill + decode) in this step.
    pub fn total_tokens(&self) -> usize {
        self.num_prefill_tokens + self.num_decode_tokens
    }
    /// True when nothing was scheduled.
    pub fn is_empty(&self) -> bool {
        self.scheduled_seq_ids.is_empty()
    }
}
/// Aggregate counters exposed by [`Scheduler`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SchedulerStats {
    pub total_requests: u64,
    pub completed_requests: u64,
    pub preemptions: u64,
    /// Mean end-to-end wait (ms) over completed requests.
    pub avg_wait_time_ms: f64,
    /// NOTE(review): never written by `Scheduler` in the code visible here —
    /// confirm whether TTFT tracking was intended (cf. `DynamicSchedulerStats`).
    pub avg_ttft_ms: f64,
    pub queue_depth: usize,
    pub running_count: usize,
}
/// Priority-aware continuous-batching scheduler backed by a paged KV cache.
pub struct Scheduler {
    /// All known requests (waiting, running, or preempted), keyed by id.
    requests: HashMap<u64, SchedulerRequest>,
    /// Max-heap: highest priority first, FIFO within a priority.
    waiting_queue: BinaryHeap<PriorityEntry>,
    /// Ids of requests currently in the batch.
    running: Vec<u64>,
    /// Ids evicted by preemption, resumed FIFO.
    preempted: VecDeque<u64>,
    /// Maximum number of concurrently running requests.
    max_batch_size: usize,
    /// Capacity of `waiting_queue` before `QueueFull` is returned.
    max_queue_size: usize,
    /// NOTE(review): configurable via `with_max_tokens` but not read by the
    /// scheduling code visible here — confirm intended enforcement.
    max_tokens_per_batch: usize,
    /// Monotonic id source for new requests.
    next_request_id: u64,
    stats: SchedulerStats,
    /// Running sum backing `stats.avg_wait_time_ms`.
    total_wait_time_ms: f64,
}
impl Scheduler {
/// Builds an empty scheduler with the given batch/queue capacities.
/// The per-batch token budget defaults to `max_batch_size * 2048` and can
/// be overridden with [`Scheduler::with_max_tokens`].
pub fn new(max_batch_size: usize, max_queue_size: usize) -> Self {
    let default_token_budget = max_batch_size * 2048;
    Self {
        requests: HashMap::new(),
        waiting_queue: BinaryHeap::new(),
        running: Vec::new(),
        preempted: VecDeque::new(),
        max_batch_size,
        max_queue_size,
        max_tokens_per_batch: default_token_budget,
        next_request_id: 0,
        stats: SchedulerStats::default(),
        total_wait_time_ms: 0.0,
    }
}
/// Builder-style override for the per-batch token budget.
pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
    self.max_tokens_per_batch = max_tokens;
    self
}
/// Enqueues a request at `Normal` priority.
///
/// # Errors
/// `SchedulerError::QueueFull` when the waiting queue is at capacity.
pub fn add_request(
    &mut self,
    input_ids: Vec<u32>,
    max_tokens: usize,
) -> Result<u64, SchedulerError> {
    self.add_request_with_priority(input_ids, max_tokens, Priority::Normal)
}
/// Enqueues a new request at `priority` and returns its assigned id.
///
/// # Errors
/// `SchedulerError::QueueFull` when the waiting queue is at capacity.
pub fn add_request_with_priority(
    &mut self,
    input_ids: Vec<u32>,
    max_tokens: usize,
    priority: Priority,
) -> Result<u64, SchedulerError> {
    let capacity = self.max_queue_size;
    if self.waiting_queue.len() >= capacity {
        return Err(SchedulerError::QueueFull { capacity });
    }
    let request_id = self.next_request_id;
    self.next_request_id += 1;
    let request = SchedulerRequest::new(request_id, input_ids, max_tokens)
        .with_priority(priority);
    // The heap entry mirrors exactly the fields the ordering uses.
    self.waiting_queue.push(PriorityEntry {
        priority,
        arrival_time: request.arrival_time,
        request_id,
    });
    self.requests.insert(request_id, request);
    self.stats.total_requests += 1;
    self.stats.queue_depth = self.waiting_queue.len();
    Ok(request_id)
}
pub fn schedule(
&mut self,
kv_cache: &mut PagedKvCache,
eos_token: u32,
) -> Result<SchedulerOutput, SchedulerError> {
let mut output = SchedulerOutput::default();
self.check_completions(&mut output, eos_token);
self.handle_preemption(&mut output, kv_cache);
self.resume_preempted(&mut output, kv_cache)?;
self.schedule_waiting(&mut output, kv_cache)?;
for &request_id in &self.running {
if let Some(request) = self.requests.get(&request_id) {
if let Some(seq_id) = request.seq_id {
output.scheduled_seq_ids.push(seq_id);
output.scheduled_request_ids.push(request_id);
if request.iterations == 0 {
output.num_prefill_tokens += request.input_ids.len();
} else {
output.num_decode_tokens += 1;
}
}
}
}
self.stats.running_count = self.running.len();
self.stats.queue_depth = self.waiting_queue.len();
Ok(output)
}
/// Records one newly generated token per request after a model step and
/// bumps each request's iteration counter. Unknown ids are ignored.
pub fn update_after_iteration(&mut self, generated_tokens: &HashMap<u64, u32>) {
    for (request_id, token) in generated_tokens {
        let Some(request) = self.requests.get_mut(request_id) else {
            continue;
        };
        request.generated_tokens.push(*token);
        request.iterations += 1;
    }
}
/// Finalizes a request: marks it `Completed`, frees its KV pages, and folds
/// its total wait time into the running average. Unknown ids only fall
/// through to the `running` cleanup below.
pub fn complete_request(&mut self, request_id: u64, kv_cache: &mut PagedKvCache) {
    if let Some(request) = self.requests.get_mut(&request_id) {
        request.state = SequenceState::Completed;
        if let Some(seq_id) = request.seq_id {
            kv_cache.free_sequence(seq_id);
        }
        self.stats.completed_requests += 1;
        let wait_time = request.wait_time().as_secs_f64() * 1000.0;
        self.total_wait_time_ms += wait_time;
        // Average over completed requests only.
        self.stats.avg_wait_time_ms =
            self.total_wait_time_ms / self.stats.completed_requests as f64;
    }
    self.running.retain(|&id| id != request_id);
}
/// Looks up a request by id.
pub fn get_request(&self, request_id: u64) -> Option<&SchedulerRequest> {
    self.requests.get(&request_id)
}
/// Current statistics snapshot.
pub fn stats(&self) -> &SchedulerStats {
    &self.stats
}
/// Marks running requests that hit EOS or their token budget as
/// `Completed` and reports their ids via `output`. Completed requests stay
/// in `self.running` (with KV pages held) until the caller invokes
/// `complete_request`.
fn check_completions(&mut self, output: &mut SchedulerOutput, eos_token: u32) {
    let mut finished = Vec::new();
    for &id in &self.running {
        let done = self
            .requests
            .get(&id)
            .is_some_and(|req| req.is_complete(eos_token));
        if done {
            finished.push(id);
        }
    }
    for id in finished {
        if let Some(req) = self.requests.get_mut(&id) {
            req.state = SequenceState::Completed;
        }
        output.completed_request_ids.push(id);
    }
}
/// Evicts at most one running request per call to make room for a
/// higher-priority waiting request.
///
/// Fires only when the batch is at capacity and the waiting queue is
/// non-empty. The victim is the first running request holding the lowest
/// priority currently running, and only if the best waiting request
/// strictly outranks it. The victim's KV pages are freed immediately
/// (recompute-style preemption), so a later resume must re-allocate for
/// its full token history.
fn handle_preemption(&mut self, output: &mut SchedulerOutput, kv_cache: &mut PagedKvCache) {
    if self.running.len() >= self.max_batch_size && !self.waiting_queue.is_empty() {
        if let Some(waiting_entry) = self.waiting_queue.peek() {
            // Lowest priority among running requests; falls back to
            // Critical (never preempted) if no running request resolves.
            let min_running_priority = self
                .running
                .iter()
                .filter_map(|&id| self.requests.get(&id))
                .map(|r| r.priority)
                .min()
                .unwrap_or(Priority::Critical);
            if waiting_entry.priority > min_running_priority {
                if let Some(&preempt_id) = self.running.iter().find(|&&id| {
                    self.requests
                        .get(&id)
                        .is_some_and(|r| r.priority == min_running_priority)
                }) {
                    if let Some(request) = self.requests.get_mut(&preempt_id) {
                        request.state = SequenceState::Preempted;
                        if let Some(seq_id) = request.seq_id {
                            output.preempted_seq_ids.push(seq_id);
                            // Free pages now; the sequence is recomputed
                            // from scratch when resumed.
                            kv_cache.free_sequence(seq_id);
                        }
                        request.seq_id = None;
                    }
                    self.running.retain(|&id| id != preempt_id);
                    self.preempted.push_back(preempt_id);
                    self.stats.preemptions += 1;
                }
            }
        }
    }
}
/// Re-admits preempted requests (FIFO) while batch slots remain.
///
/// Each resumed request gets a fresh KV allocation sized for its full
/// prompt-plus-generated history (recompute preemption). On the first
/// allocation failure the request is pushed back to the queue front and
/// resumption stops for this iteration.
///
/// Always returns `Ok` today; the `Result` is kept for interface stability
/// (see the `unnecessary_wraps` allow at the top of the file).
fn resume_preempted(
    &mut self,
    _output: &mut SchedulerOutput,
    kv_cache: &mut PagedKvCache,
) -> Result<(), SchedulerError> {
    while self.running.len() < self.max_batch_size {
        if let Some(request_id) = self.preempted.pop_front() {
            if let Some(request) = self.requests.get_mut(&request_id) {
                // Needs room for the prompt plus everything generated so far.
                let total_tokens = request.total_tokens();
                match kv_cache.allocate_sequence(total_tokens) {
                    Ok(seq_id) => {
                        request.seq_id = Some(seq_id);
                        request.state = SequenceState::Running;
                        self.running.push(request_id);
                    },
                    Err(_) => {
                        // Out of cache memory: retry this request first on
                        // the next iteration.
                        self.preempted.push_front(request_id);
                        break;
                    },
                }
            }
        } else {
            break;
        }
    }
    Ok(())
}
/// Admits waiting requests (highest priority first) while batch slots
/// remain, allocating KV space for each prompt.
///
/// On the first allocation failure the popped entry is pushed back and
/// admission stops. NOTE(review): `self.max_tokens_per_batch` is not
/// consulted here — only slot count limits admission; confirm whether a
/// token-budget check was intended.
///
/// Always returns `Ok` today; the `Result` is kept for interface stability.
fn schedule_waiting(
    &mut self,
    _output: &mut SchedulerOutput,
    kv_cache: &mut PagedKvCache,
) -> Result<(), SchedulerError> {
    while self.running.len() < self.max_batch_size {
        if let Some(entry) = self.waiting_queue.pop() {
            if let Some(request) = self.requests.get_mut(&entry.request_id) {
                // New admissions only need room for the prompt.
                let total_tokens = request.input_ids.len();
                match kv_cache.allocate_sequence(total_tokens) {
                    Ok(seq_id) => {
                        request.seq_id = Some(seq_id);
                        request.state = SequenceState::Running;
                        self.running.push(entry.request_id);
                    },
                    Err(_) => {
                        // Put the entry back; it stays the best candidate.
                        self.waiting_queue.push(entry);
                        break;
                    },
                }
            }
        } else {
            break;
        }
    }
    Ok(())
}
/// Number of requests waiting for admission.
pub fn waiting_count(&self) -> usize {
    self.waiting_queue.len()
}
/// Number of requests currently in the batch.
pub fn running_count(&self) -> usize {
    self.running.len()
}
/// Number of preempted requests awaiting resumption.
pub fn preempted_count(&self) -> usize {
    self.preempted.len()
}
}
/// Phase of a generation slot.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SlotState {
    /// No request bound to the slot.
    Idle,
    /// Prompt is being processed (prefill).
    Processing,
    /// Tokens are being generated (decode).
    Generating,
}
impl Default for SlotState {
    /// Slots start out idle.
    fn default() -> Self {
        Self::Idle
    }
}
/// One server-style generation slot with per-request bookkeeping and timing.
#[derive(Debug, Clone)]
pub struct Slot {
    pub id: usize,
    pub state: SlotState,
    /// Request currently bound to the slot, if any.
    pub request_id: Option<u64>,
    /// KV-cache sequence handle for the bound request.
    pub seq_id: Option<SeqId>,
    /// Prompt tokens of the bound request.
    pub input_tokens: Vec<u32>,
    /// Tokens generated so far (kept after `finish` until the next `assign`).
    pub generated_tokens: Vec<u32>,
    /// Generation budget for the bound request.
    pub max_tokens: usize,
    /// Prompt tokens already prefilled.
    pub n_prompt_tokens_processed: usize,
    /// Set when generation starts; used to compute `generation_time_ms`.
    pub generation_start: Option<Instant>,
    /// Wall time spent on the prompt pass (ms).
    pub prompt_time_ms: f64,
    /// Wall time spent generating (ms), captured by `finish`.
    pub generation_time_ms: f64,
}
impl Slot {
pub fn new(id: usize) -> Self {
Self {
id,
state: SlotState::Idle,
request_id: None,
seq_id: None,
input_tokens: Vec::new(),
generated_tokens: Vec::new(),
max_tokens: 0,
n_prompt_tokens_processed: 0,
generation_start: None,
prompt_time_ms: 0.0,
generation_time_ms: 0.0,
}
}
pub fn is_idle(&self) -> bool {
self.state == SlotState::Idle
}
pub fn is_generating(&self) -> bool {
self.state == SlotState::Generating
}
pub fn assign(&mut self, request_id: u64, input_tokens: Vec<u32>, max_tokens: usize) {
self.state = SlotState::Processing;
self.request_id = Some(request_id);
self.input_tokens = input_tokens;
self.max_tokens = max_tokens;
self.generated_tokens.clear();
self.n_prompt_tokens_processed = 0;
self.prompt_time_ms = 0.0;
self.generation_time_ms = 0.0;
self.generation_start = None;
}
pub fn start_generation(&mut self, prompt_time_ms: f64) {
self.state = SlotState::Generating;
self.prompt_time_ms = prompt_time_ms;
self.generation_start = Some(Instant::now());
}
pub fn add_token(&mut self, token: u32) {
self.generated_tokens.push(token);
}
pub fn is_complete(&self, eos_token: u32) -> bool {
if self.generated_tokens.len() >= self.max_tokens {
return true;
}
if let Some(&last) = self.generated_tokens.last() {
if last == eos_token {
return true;
}
}
false
}
pub fn finish(&mut self) {
if let Some(start) = self.generation_start {
self.generation_time_ms = start.elapsed().as_secs_f64() * 1000.0;
}
self.state = SlotState::Idle;
self.request_id = None;
self.seq_id = None;
}
pub fn tokens_per_second(&self) -> f64 {
if self.generation_time_ms > 0.0 {
self.generated_tokens.len() as f64 / (self.generation_time_ms / 1000.0)
} else {
0.0
}
}
}
/// Fixed-size pool of generation [`Slot`]s sharing a context-length cap.
#[derive(Debug)]
pub struct SlotManager {
    slots: Vec<Slot>,
    /// Context-length limit shared by all slots.
    pub max_context_length: usize,
    /// Monotonic id source for requests assigned through this manager.
    next_request_id: u64,
}
impl SlotManager {
    /// Creates `num_slots` idle slots sharing a context-length limit.
    pub fn new(num_slots: usize, max_context_length: usize) -> Self {
        Self {
            slots: (0..num_slots).map(Slot::new).collect(),
            max_context_length,
            next_request_id: 0,
        }
    }
    /// Total number of slots, busy or not.
    pub fn num_slots(&self) -> usize {
        self.slots.len()
    }
    /// Number of slots currently idle.
    pub fn num_idle_slots(&self) -> usize {
        self.slots.iter().filter(|slot| slot.is_idle()).count()
    }
    /// Number of slots currently serving a request.
    pub fn num_active_slots(&self) -> usize {
        self.num_slots() - self.num_idle_slots()
    }
    /// Index of the first idle slot, if any.
    pub fn find_idle_slot(&self) -> Option<usize> {
        self.slots.iter().position(|slot| slot.is_idle())
    }
    /// Binds the request to the first idle slot, minting a request id.
    /// Returns `(slot_id, request_id)`, or `None` when all slots are busy.
    pub fn assign_request(
        &mut self,
        input_tokens: Vec<u32>,
        max_tokens: usize,
    ) -> Option<(usize, u64)> {
        let slot_id = self.find_idle_slot()?;
        let request_id = self.next_request_id;
        self.next_request_id += 1;
        self.slots[slot_id].assign(request_id, input_tokens, max_tokens);
        Some((slot_id, request_id))
    }
    /// Immutable access to a slot by id.
    pub fn get_slot(&self, slot_id: usize) -> Option<&Slot> {
        self.slots.get(slot_id)
    }
    /// Mutable access to a slot by id.
    pub fn get_slot_mut(&mut self, slot_id: usize) -> Option<&mut Slot> {
        self.slots.get_mut(slot_id)
    }
    /// All non-idle slots, in slot order.
    pub fn active_slots(&self) -> impl Iterator<Item = &Slot> {
        self.slots.iter().filter(|slot| !slot.is_idle())
    }
    /// All slots currently in the decode phase.
    pub fn generating_slots(&self) -> impl Iterator<Item = &Slot> {
        self.slots.iter().filter(|slot| slot.is_generating())
    }
    /// Indices of generating slots, in slot order.
    pub fn batch_slots(&self) -> Vec<usize> {
        let mut indices = Vec::new();
        for (idx, slot) in self.slots.iter().enumerate() {
            if slot.is_generating() {
                indices.push(idx);
            }
        }
        indices
    }
    /// Fraction of slots in use (0.0 when the pool is empty).
    pub fn utilization(&self) -> f64 {
        if self.slots.is_empty() {
            return 0.0;
        }
        self.num_active_slots() as f64 / self.slots.len() as f64
    }
    /// Sum of per-slot decode throughput.
    pub fn aggregate_tokens_per_second(&self) -> f64 {
        self.slots.iter().map(Slot::tokens_per_second).sum()
    }
}
/// Composition of a micro-batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BatchType {
    /// Prompt tokens only.
    Prefill,
    /// Decode tokens only (also the empty-batch state).
    Decode,
    /// Both prompt and decode tokens.
    Mixed,
}
impl Default for BatchType {
    /// Empty batches report `Decode` (see `MicroBatch::update_batch_type`).
    fn default() -> Self {
        Self::Decode
    }
}
/// One token scheduled into a micro-batch, with its placement metadata.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct BatchToken {
    pub token_id: u32,
    /// Index of the owning sequence within the sequence batch.
    pub seq_idx: usize,
    /// Position of this token within its sequence.
    pub seq_pos: usize,
    /// True for prompt (prefill) tokens, false for decode tokens.
    pub is_prompt: bool,
}
impl BatchToken {
    /// Bundles a token with its sequence index, position, and phase flag.
    pub fn new(token_id: u32, seq_idx: usize, seq_pos: usize, is_prompt: bool) -> Self {
        Self {
            token_id,
            seq_idx,
            seq_pos,
            is_prompt,
        }
    }
}
/// A token-level micro-batch assembled for one model forward pass.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MicroBatch {
    /// Tokens in insertion order.
    pub tokens: Vec<BatchToken>,
    /// Distinct sequence indices contributing tokens, in first-seen order.
    pub seq_indices: Vec<usize>,
    /// Derived from the phase counters below.
    pub batch_type: BatchType,
    /// One past the highest `seq_pos` seen.
    pub max_seq_len: usize,
    pub n_prompt_tokens: usize,
    pub n_decode_tokens: usize,
}
impl MicroBatch {
    /// Creates an empty micro-batch. Delegates to the derived `Default`
    /// instead of hand-listing every field (the previous duplication would
    /// silently drift if fields were added).
    pub fn new() -> Self {
        Self::default()
    }
    /// Empty micro-batch whose token buffer is pre-sized to `capacity`.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            tokens: Vec::with_capacity(capacity),
            ..Self::default()
        }
    }
    /// Appends a token, updating phase counters, the distinct-sequence
    /// list, the running max position, and the derived batch type.
    pub fn add_token(&mut self, token: BatchToken) {
        if token.is_prompt {
            self.n_prompt_tokens += 1;
        } else {
            self.n_decode_tokens += 1;
        }
        // Linear membership check is fine: micro-batches hold few
        // distinct sequences.
        if !self.seq_indices.contains(&token.seq_idx) {
            self.seq_indices.push(token.seq_idx);
        }
        self.max_seq_len = self.max_seq_len.max(token.seq_pos + 1);
        self.tokens.push(token);
        self.update_batch_type();
    }
    /// Re-derives `batch_type` from the phase counters (empty => Decode).
    fn update_batch_type(&mut self) {
        self.batch_type = match (self.n_prompt_tokens > 0, self.n_decode_tokens > 0) {
            (true, false) => BatchType::Prefill,
            (true, true) => BatchType::Mixed,
            (false, _) => BatchType::Decode,
        };
    }
    /// Number of tokens in the batch.
    pub fn len(&self) -> usize {
        self.tokens.len()
    }
    /// True when no tokens have been added.
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }
    /// Number of distinct sequences contributing tokens.
    pub fn num_sequences(&self) -> usize {
        self.seq_indices.len()
    }
    /// True when the batch holds only prompt tokens.
    pub fn is_prefill(&self) -> bool {
        self.batch_type == BatchType::Prefill
    }
    /// True when the batch holds only decode tokens (or is empty).
    pub fn is_decode(&self) -> bool {
        self.batch_type == BatchType::Decode
    }
    /// True when the batch mixes prompt and decode tokens.
    pub fn is_mixed(&self) -> bool {
        self.batch_type == BatchType::Mixed
    }
    /// Token ids in insertion order.
    pub fn token_ids(&self) -> Vec<u32> {
        self.tokens.iter().map(|t| t.token_id).collect()
    }
    /// Per-token sequence positions in insertion order.
    pub fn positions(&self) -> Vec<usize> {
        self.tokens.iter().map(|t| t.seq_pos).collect()
    }
    /// Resets to the default (empty, `Decode`) state while keeping the
    /// allocated capacity for reuse.
    pub fn clear(&mut self) {
        self.tokens.clear();
        self.seq_indices.clear();
        self.batch_type = BatchType::Decode;
        self.max_seq_len = 0;
        self.n_prompt_tokens = 0;
        self.n_decode_tokens = 0;
    }
}
/// One sequence's membership record within a [`SequenceBatch`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SequenceBatchEntry {
    pub seq_idx: usize,
    /// Slot serving this sequence.
    pub slot_id: usize,
    pub request_id: u64,
    /// Current decode position (0 during prefill).
    pub position: usize,
    /// Prompt tokens; cleared when decode starts.
    pub tokens: Vec<u32>,
    /// True until the prompt has been processed.
    pub is_prefill: bool,
}
impl SequenceBatchEntry {
pub fn new(seq_idx: usize, slot_id: usize, request_id: u64) -> Self {
Self {
seq_idx,
slot_id,
request_id,
position: 0,
tokens: Vec::new(),
is_prefill: true,
}
}
pub fn with_tokens(mut self, tokens: Vec<u32>) -> Self {
self.tokens = tokens;
self
}
pub fn at_position(mut self, position: usize) -> Self {
self.position = position;
self
}
pub fn decoding(mut self) -> Self {
self.is_prefill = false;
self
}
}
/// Bounded collection of sequences being batched together.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SequenceBatch {
    pub sequences: Vec<SequenceBatchEntry>,
    /// Capacity; additions beyond this are rejected.
    pub max_batch_size: usize,
    /// Occupancy ratio, maintained on every add/remove.
    pub utilization: f64,
}
impl SequenceBatch {
pub fn new(max_batch_size: usize) -> Self {
Self {
sequences: Vec::with_capacity(max_batch_size),
max_batch_size,
utilization: 0.0,
}
}
pub fn add_sequence(&mut self, entry: SequenceBatchEntry) -> bool {
if self.sequences.len() >= self.max_batch_size {
return false;
}
self.sequences.push(entry);
self.update_utilization();
true
}
pub fn remove_sequence(&mut self, seq_idx: usize) -> Option<SequenceBatchEntry> {
let pos = self.sequences.iter().position(|s| s.seq_idx == seq_idx)?;
let entry = self.sequences.remove(pos);
self.update_utilization();
Some(entry)
}
fn update_utilization(&mut self) {
self.utilization = if self.max_batch_size > 0 {
self.sequences.len() as f64 / self.max_batch_size as f64
} else {
0.0
};
}
pub fn len(&self) -> usize {
self.sequences.len()
}
pub fn is_empty(&self) -> bool {
self.sequences.is_empty()
}
pub fn is_full(&self) -> bool {
self.sequences.len() >= self.max_batch_size
}
pub fn prefill_sequences(&self) -> impl Iterator<Item = &SequenceBatchEntry> {
self.sequences.iter().filter(|s| s.is_prefill)
}
pub fn decode_sequences(&self) -> impl Iterator<Item = &SequenceBatchEntry> {
self.sequences.iter().filter(|s| !s.is_prefill)
}
pub fn num_prefill(&self) -> usize {
self.sequences.iter().filter(|s| s.is_prefill).count()
}
pub fn num_decode(&self) -> usize {
self.sequences.iter().filter(|s| !s.is_prefill).count()
}
pub fn clear(&mut self) {
self.sequences.clear();
self.utilization = 0.0;
}
pub fn get(&self, seq_idx: usize) -> Option<&SequenceBatchEntry> {
self.sequences.iter().find(|s| s.seq_idx == seq_idx)
}
pub fn get_mut(&mut self, seq_idx: usize) -> Option<&mut SequenceBatchEntry> {
self.sequences.iter_mut().find(|s| s.seq_idx == seq_idx)
}
}
/// Limits and policy knobs for [`BatchScheduler`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Maximum tokens packed into one micro-batch.
    pub max_ubatch_tokens: usize,
    /// Maximum sequences in the sequence batch.
    pub max_sbatch_sequences: usize,
    /// Return prompt-only micro-batches so decode runs in pure-decode batches.
    pub prefer_pure_decode: bool,
    /// NOTE(review): stored but not read in the code visible here.
    pub max_context_length: usize,
    /// NOTE(review): stored but not read in the code visible here.
    pub dynamic_batching: bool,
}
impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            max_ubatch_tokens: 512,
            max_sbatch_sequences: 8,
            prefer_pure_decode: true,
            max_context_length: 2048,
            dynamic_batching: true,
        }
    }
}
impl BatchConfig {
    /// Builder: overrides the micro-batch token limit.
    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_ubatch_tokens = max_tokens;
        self
    }
    /// Builder: overrides the sequence-batch size limit.
    pub fn with_max_sequences(mut self, max_seqs: usize) -> Self {
        self.max_sbatch_sequences = max_seqs;
        self
    }
}
/// Packs sequences into micro-batches under [`BatchConfig`] limits.
pub struct BatchScheduler {
    config: BatchConfig,
    sbatch: SequenceBatch,
    /// Monotonic source of sequence indices.
    next_seq_idx: usize,
    stats: BatchStats,
}
/// Counters accumulated by [`BatchScheduler`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BatchStats {
    pub ubatches_created: u64,
    /// NOTE(review): never incremented in the code visible here.
    pub sbatches_created: u64,
    pub tokens_processed: u64,
    pub prefill_tokens: u64,
    pub decode_tokens: u64,
    /// Incremental mean of micro-batch sizes.
    pub avg_ubatch_size: f64,
    /// NOTE(review): never updated in the code visible here.
    pub avg_sbatch_size: f64,
}
impl BatchScheduler {
/// Scheduler with the default batching limits.
pub fn new() -> Self {
    Self::with_config(BatchConfig::default())
}
/// Scheduler using an explicit batching configuration.
pub fn with_config(config: BatchConfig) -> Self {
    // Build the sequence batch before `config` is moved into the struct.
    let sbatch = SequenceBatch::new(config.max_sbatch_sequences);
    Self {
        config,
        sbatch,
        next_seq_idx: 0,
        stats: BatchStats::default(),
    }
}
/// Registers a sequence, assigning it a fresh monotonically increasing
/// index. Returns `None` when the sequence batch is already full.
pub fn add_sequence(
    &mut self,
    slot_id: usize,
    request_id: u64,
    input_tokens: Vec<u32>,
) -> Option<usize> {
    if self.sbatch.is_full() {
        return None;
    }
    let seq_idx = self.next_seq_idx;
    self.next_seq_idx += 1;
    let entry =
        SequenceBatchEntry::new(seq_idx, slot_id, request_id).with_tokens(input_tokens);
    // The insert can only fail on a full batch, which was ruled out above.
    self.sbatch.add_sequence(entry).then_some(seq_idx)
}
/// Drops a finished sequence from the batch and returns its entry.
pub fn complete_sequence(&mut self, seq_idx: usize) -> Option<SequenceBatchEntry> {
    self.sbatch.remove_sequence(seq_idx)
}
/// Transitions a sequence from prefill to decode at `position`,
/// discarding its buffered prompt tokens. Returns false for unknown indices.
pub fn start_decode(&mut self, seq_idx: usize, position: usize) -> bool {
    let Some(entry) = self.sbatch.get_mut(seq_idx) else {
        return false;
    };
    entry.is_prefill = false;
    entry.position = position;
    entry.tokens.clear();
    true
}
/// Builds the next micro-batch from the current sequence batch.
///
/// Prompt tokens from prefilling sequences are packed first, up to
/// `max_ubatch_tokens`. With `prefer_pure_decode` set, a non-empty
/// prompt-only batch is returned immediately so decode steps run in
/// pure-decode batches; otherwise decode tokens are appended.
pub fn create_ubatch(&mut self) -> MicroBatch {
    let mut ubatch = MicroBatch::with_capacity(self.config.max_ubatch_tokens);
    for entry in &self.sbatch.sequences {
        if entry.is_prefill {
            for (i, &token_id) in entry.tokens.iter().enumerate() {
                if ubatch.len() >= self.config.max_ubatch_tokens {
                    break;
                }
                ubatch.add_token(BatchToken::new(token_id, entry.seq_idx, i, true));
            }
        }
    }
    // Only prompt tokens have been added so far, so a non-empty batch here
    // is necessarily pure prefill (the is_prefill() check is redundant but
    // harmless).
    if self.config.prefer_pure_decode && !ubatch.is_empty() && ubatch.is_prefill() {
        self.record_ubatch(&ubatch);
        return ubatch;
    }
    for entry in &self.sbatch.sequences {
        if !entry.is_prefill {
            if ubatch.len() >= self.config.max_ubatch_tokens {
                break;
            }
            // Token id 0 is a placeholder for decode positions; presumably
            // the caller substitutes the actual sampled token — TODO confirm.
            ubatch.add_token(BatchToken::new(
                0, entry.seq_idx,
                entry.position,
                false,
            ));
        }
    }
    self.record_ubatch(&ubatch);
    ubatch
}
/// Folds a non-empty micro-batch into the running statistics; empty
/// batches are ignored so they do not skew the average.
fn record_ubatch(&mut self, ubatch: &MicroBatch) {
    if ubatch.is_empty() {
        return;
    }
    self.stats.ubatches_created += 1;
    self.stats.tokens_processed += ubatch.len() as u64;
    self.stats.prefill_tokens += ubatch.n_prompt_tokens as u64;
    self.stats.decode_tokens += ubatch.n_decode_tokens as u64;
    let n = self.stats.ubatches_created as f64;
    // Incremental mean: new_avg = old_avg * (n-1)/n + size/n.
    self.stats.avg_ubatch_size =
        self.stats.avg_ubatch_size * (n - 1.0) / n + ubatch.len() as f64 / n;
}
/// Current sequence batch.
pub fn sbatch(&self) -> &SequenceBatch {
    &self.sbatch
}
/// Accumulated batching statistics.
pub fn stats(&self) -> &BatchStats {
    &self.stats
}
/// Active configuration.
pub fn config(&self) -> &BatchConfig {
    &self.config
}
/// Number of sequences currently batched.
pub fn num_sequences(&self) -> usize {
    self.sbatch.len()
}
/// True while more sequences can be admitted.
pub fn has_capacity(&self) -> bool {
    !self.sbatch.is_full()
}
/// Sequence-batch occupancy ratio.
pub fn utilization(&self) -> f64 {
    self.sbatch.utilization
}
}
impl Default for BatchScheduler {
    /// Same as [`BatchScheduler::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Latency expectations attached to a request.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Deadline {
    /// Soft latency target (ms) used for SLA accounting and urgency scoring.
    pub target_latency_ms: u64,
    /// Optional hard cutoff (ms); queued requests older than this are dropped.
    pub hard_deadline_ms: Option<u64>,
    /// Fraction of requests expected to meet the target.
    pub sla_target: f64,
}
impl Default for Deadline {
    /// 1 s soft target, no hard deadline, 99% SLA.
    fn default() -> Self {
        Self {
            target_latency_ms: 1000,
            hard_deadline_ms: None,
            sla_target: 0.99,
        }
    }
}
impl Deadline {
pub fn with_target(target_ms: u64) -> Self {
Self {
target_latency_ms: target_ms,
..Default::default()
}
}
pub fn strict(target_ms: u64, hard_ms: u64) -> Self {
Self {
target_latency_ms: target_ms,
hard_deadline_ms: Some(hard_ms),
sla_target: 1.0,
}
}
}
/// Tuning knobs for [`DynamicPriorityScheduler`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicPriorityConfig {
    /// Promote long-waiting requests to higher priority classes.
    pub enable_age_promotion: bool,
    /// Base wait (ms) required for each successive promotion.
    pub promotion_interval_ms: u64,
    /// Ceiling for age-based promotion.
    pub max_promoted_priority: Priority,
    /// Fraction of the token budget per class, indexed by `Priority as usize`.
    pub priority_budgets: [f64; 4],
    /// Reorder each class by urgency before admission.
    pub enable_deadline_scheduling: bool,
    /// NOTE(review): not read anywhere in the code visible here — confirm use.
    pub urgency_factor: f64,
    /// Smallest token grant a scheduled request receives.
    pub min_tokens_per_request: usize,
    /// Enforce per-class budgets; otherwise classes share the full budget.
    pub enable_fair_share: bool,
}
impl Default for DynamicPriorityConfig {
    fn default() -> Self {
        Self {
            enable_age_promotion: true,
            promotion_interval_ms: 5000,
            max_promoted_priority: Priority::High,
            // Low 5%, Normal 30%, High 40%, Critical 25% (sums to 1.0).
            priority_budgets: [0.05, 0.30, 0.40, 0.25],
            enable_deadline_scheduling: true,
            urgency_factor: 2.0,
            min_tokens_per_request: 1,
            enable_fair_share: true,
        }
    }
}
impl DynamicPriorityConfig {
pub fn with_budgets(budgets: [f64; 4]) -> Self {
Self {
priority_budgets: budgets,
..Default::default()
}
}
pub fn no_promotion(mut self) -> Self {
self.enable_age_promotion = false;
self
}
pub fn with_promotion_interval(mut self, ms: u64) -> Self {
self.promotion_interval_ms = ms;
self
}
}
/// A request tracked by [`DynamicPriorityScheduler`].
#[derive(Debug, Clone)]
pub struct DynamicRequest {
    pub request_id: u64,
    pub input_ids: Vec<u32>,
    pub max_tokens: usize,
    /// Priority at submission time.
    pub original_priority: Priority,
    /// Current priority after any age-based promotions.
    pub effective_priority: Priority,
    pub arrival_time: Instant,
    pub deadline: Option<Deadline>,
    /// Number of promotions applied so far.
    pub promotions: u32,
    pub state: SequenceState,
    pub generated_tokens: Vec<u32>,
    /// KV-cache handle; `None` while not running.
    pub seq_id: Option<SeqId>,
    /// Time to first token (ms), stamped on first scheduling.
    pub ttft_ms: Option<f64>,
}
impl DynamicRequest {
pub fn new(request_id: u64, input_ids: Vec<u32>, max_tokens: usize) -> Self {
Self {
request_id,
input_ids,
max_tokens,
original_priority: Priority::Normal,
effective_priority: Priority::Normal,
arrival_time: Instant::now(),
deadline: None,
promotions: 0,
state: SequenceState::Waiting,
generated_tokens: Vec::new(),
seq_id: None,
ttft_ms: None,
}
}
pub fn with_priority(mut self, priority: Priority) -> Self {
self.original_priority = priority;
self.effective_priority = priority;
self
}
pub fn with_deadline(mut self, deadline: Deadline) -> Self {
self.deadline = Some(deadline);
self
}
pub fn wait_time_ms(&self) -> u64 {
self.arrival_time.elapsed().as_millis() as u64
}
pub fn is_urgent(&self) -> bool {
if let Some(deadline) = &self.deadline {
let elapsed = self.wait_time_ms();
elapsed >= deadline.target_latency_ms / 2
} else {
false
}
}
pub fn is_expired(&self) -> bool {
if let Some(deadline) = &self.deadline {
if let Some(hard) = deadline.hard_deadline_ms {
return self.wait_time_ms() > hard;
}
}
false
}
pub fn urgency_score(&self) -> f64 {
if let Some(deadline) = &self.deadline {
let elapsed = self.wait_time_ms() as f64;
let target = deadline.target_latency_ms as f64;
if target > 0.0 {
elapsed / target
} else {
0.0
}
} else {
0.0
}
}
pub fn remaining_tokens(&self) -> usize {
self.max_tokens.saturating_sub(self.generated_tokens.len())
}
pub fn total_tokens(&self) -> usize {
self.input_ids.len() + self.generated_tokens.len()
}
}
/// Ordering key for dynamic scheduling decisions.
/// NOTE(review): marked `dead_code` and unused in the code visible here —
/// the scheduler sorts its queues in place instead; confirm before removal.
#[derive(Debug)]
#[allow(dead_code)]
struct DynamicPriorityEntry {
    request_id: u64,
    effective_priority: Priority,
    urgency_score: f64,
    arrival_time: Instant,
}
/// Equality is identity by request id (not by the ordering fields).
impl PartialEq for DynamicPriorityEntry {
    fn eq(&self, other: &Self) -> bool {
        self.request_id == other.request_id
    }
}
impl Eq for DynamicPriorityEntry {}
impl PartialOrd for DynamicPriorityEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Orders by effective priority, then urgency (NaN compares equal), then
/// FIFO — the reversed arrival comparison makes earlier arrivals greater
/// so a max-heap would pop them first.
impl Ord for DynamicPriorityEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match self.effective_priority.cmp(&other.effective_priority) {
            std::cmp::Ordering::Equal => {
                match self
                    .urgency_score
                    .partial_cmp(&other.urgency_score)
                    .unwrap_or(std::cmp::Ordering::Equal)
                {
                    std::cmp::Ordering::Equal => {
                        // Reversed on purpose: earlier arrival => Greater.
                        other.arrival_time.cmp(&self.arrival_time)
                    },
                    ord => ord,
                }
            },
            ord => ord,
        }
    }
}
/// Counters exposed by [`DynamicPriorityScheduler`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DynamicSchedulerStats {
    pub total_requests: u64,
    pub completed_requests: u64,
    /// Completions within the soft latency target.
    pub sla_met: u64,
    pub sla_missed: u64,
    /// Requests dropped for exceeding a hard deadline.
    pub dropped_requests: u64,
    /// Age-based priority promotions performed.
    pub promotions: u64,
    pub avg_ttft_ms: f64,
    pub p99_ttft_ms: f64,
    /// Tokens granted, indexed by `Priority as usize`.
    pub tokens_by_priority: [u64; 4],
    /// Queue lengths, indexed by `Priority as usize`.
    pub queue_depth_by_priority: [usize; 4],
}
/// Token-budgeted scheduler with per-class queues, age promotion, and
/// deadline-aware ordering.
pub struct DynamicPriorityScheduler {
    config: DynamicPriorityConfig,
    requests: HashMap<u64, DynamicRequest>,
    /// One FIFO queue per priority class (index == `Priority as usize`).
    priority_queues: [VecDeque<u64>; 4],
    running: Vec<u64>,
    next_request_id: u64,
    stats: DynamicSchedulerStats,
    /// All recorded TTFT samples (grows unboundedly; see `update_ttft_stats`).
    ttft_samples: Vec<f64>,
    /// Total tokens distributable per scheduling pass.
    batch_token_budget: usize,
}
impl DynamicPriorityScheduler {
/// Scheduler with the default dynamic-priority configuration.
pub fn new(batch_token_budget: usize) -> Self {
    Self::with_config(batch_token_budget, DynamicPriorityConfig::default())
}
/// Scheduler with an explicit configuration and per-pass token budget.
pub fn with_config(batch_token_budget: usize, config: DynamicPriorityConfig) -> Self {
    Self {
        config,
        requests: HashMap::new(),
        // One empty FIFO queue per priority class.
        priority_queues: std::array::from_fn(|_| VecDeque::new()),
        running: Vec::new(),
        next_request_id: 0,
        stats: DynamicSchedulerStats::default(),
        ttft_samples: Vec::new(),
        batch_token_budget,
    }
}
/// Enqueues a request at `priority`, optionally with a deadline, and
/// returns its id. Queues are unbounded, so admission never fails.
pub fn add_request(
    &mut self,
    input_ids: Vec<u32>,
    max_tokens: usize,
    priority: Priority,
    deadline: Option<Deadline>,
) -> u64 {
    let request_id = self.next_request_id;
    self.next_request_id += 1;
    let mut request =
        DynamicRequest::new(request_id, input_ids, max_tokens).with_priority(priority);
    if let Some(deadline) = deadline {
        request = request.with_deadline(deadline);
    }
    // Queue index mirrors the priority discriminant.
    self.priority_queues[priority as usize].push_back(request_id);
    self.requests.insert(request_id, request);
    self.stats.total_requests += 1;
    self.update_queue_depths();
    request_id
}
/// Convenience wrapper: normal priority, no deadline.
pub fn add_simple_request(&mut self, input_ids: Vec<u32>, max_tokens: usize) -> u64 {
    self.add_request(input_ids, max_tokens, Priority::Normal, None)
}
/// Promotes queued requests that have waited past their threshold.
///
/// A request's k-th promotion requires a total wait of
/// `promotion_interval_ms * (k + 1)`, so each step up costs progressively
/// more waiting. Queues are scanned Low→High; because `promote_request`
/// pushes into the next queue, a request promoted early in the pass can be
/// examined again and climb several classes in one call if it has waited
/// long enough. Promotion never exceeds `max_promoted_priority`.
pub fn promote_aged_requests(&mut self) {
    if !self.config.enable_age_promotion {
        return;
    }
    let promotion_threshold = self.config.promotion_interval_ms;
    let max_priority = self.config.max_promoted_priority;
    // Critical (index 3) is never a promotion source.
    for queue_idx in 0..3 {
        let current_priority = match queue_idx {
            0 => Priority::Low,
            1 => Priority::Normal,
            2 => Priority::High,
            _ => continue,
        };
        if current_priority >= max_priority {
            continue;
        }
        let mut to_promote = Vec::new();
        for &request_id in &self.priority_queues[queue_idx] {
            if let Some(request) = self.requests.get(&request_id) {
                // Escalating threshold per promotion already received.
                let promotions_time = promotion_threshold * (request.promotions as u64 + 1);
                if request.wait_time_ms() >= promotions_time {
                    to_promote.push(request_id);
                }
            }
        }
        for request_id in to_promote {
            self.promote_request(request_id);
        }
    }
}
/// Moves one request up a single priority class: removes it from its
/// current queue, bumps `effective_priority`, and pushes it to the *front*
/// of the next queue so its aged seniority is preserved.
fn promote_request(&mut self, request_id: u64) {
    if let Some(request) = self.requests.get_mut(&request_id) {
        let current_idx = request.effective_priority as usize;
        let max_idx = self.config.max_promoted_priority as usize;
        if current_idx < max_idx {
            self.priority_queues[current_idx].retain(|&id| id != request_id);
            let new_priority = match current_idx + 1 {
                1 => Priority::Normal,
                2 => Priority::High,
                3 => Priority::Critical,
                // Unreachable while current_idx < max_idx <= 3.
                _ => return,
            };
            request.effective_priority = new_priority;
            request.promotions += 1;
            self.priority_queues[current_idx + 1].push_front(request_id);
            self.stats.promotions += 1;
        }
    }
}
pub fn drop_expired(&mut self) -> Vec<u64> {
let mut dropped = Vec::new();
for queue in &mut self.priority_queues {
let mut to_remove = Vec::new();
for &request_id in queue.iter() {
if let Some(request) = self.requests.get(&request_id) {
if request.is_expired() {
to_remove.push(request_id);
}
}
}
for request_id in to_remove {
queue.retain(|&id| id != request_id);
if let Some(mut request) = self.requests.remove(&request_id) {
request.state = SequenceState::Failed;
dropped.push(request_id);
self.stats.dropped_requests += 1;
}
}
}
self.update_queue_depths();
dropped
}
/// Runs one scheduling pass and returns `(request_id, token_grant)` pairs,
/// highest priority class first.
///
/// Steps: age-based promotion, hard-deadline drops, then per-class
/// admission constrained by `available_slots`, the global token budget,
/// and (with fair-share enabled) each class's fractional budget. Within a
/// class, deadline scheduling reorders the queue by descending urgency
/// before admission. Scheduled requests move to `running` and record
/// their TTFT the first time they are scheduled.
pub fn schedule(&mut self, available_slots: usize) -> Vec<(u64, usize)> {
    self.promote_aged_requests();
    self.drop_expired();
    let mut scheduled = Vec::new();
    let mut remaining_budget = self.batch_token_budget;
    let mut remaining_slots = available_slots;
    // Per-class budgets; without fair-share every class may draw on the
    // whole remaining budget.
    let budgets: [usize; 4] = if self.config.enable_fair_share {
        self.config
            .priority_budgets
            .map(|b| (b * self.batch_token_budget as f64) as usize)
    } else {
        [
            remaining_budget,
            remaining_budget,
            remaining_budget,
            remaining_budget,
        ]
    };
    // Serve Critical (index 3) down to Low (index 0).
    for queue_idx in (0..4).rev() {
        if remaining_slots == 0 || remaining_budget == 0 {
            break;
        }
        let queue = &mut self.priority_queues[queue_idx];
        let mut priority_budget = budgets[queue_idx].min(remaining_budget);
        if self.config.enable_deadline_scheduling {
            // Rebuild the queue ordered by descending urgency so the most
            // behind-schedule requests are admitted first.
            let mut sorted: Vec<_> = queue.iter().copied().collect();
            sorted.sort_by(|&a, &b| {
                let req_a = self.requests.get(&a);
                let req_b = self.requests.get(&b);
                match (req_a, req_b) {
                    (Some(ra), Some(rb)) => rb
                        .urgency_score()
                        .partial_cmp(&ra.urgency_score())
                        .unwrap_or(std::cmp::Ordering::Equal),
                    _ => std::cmp::Ordering::Equal,
                }
            });
            *queue = sorted.into_iter().collect();
        }
        let mut scheduled_from_queue = Vec::new();
        for &request_id in queue.iter() {
            if remaining_slots == 0 || priority_budget < self.config.min_tokens_per_request {
                break;
            }
            if let Some(request) = self.requests.get(&request_id) {
                // Grant what the request still needs, capped by the class
                // budget but never below the per-request minimum.
                let tokens_needed = request.remaining_tokens().max(1);
                let tokens_to_allocate = tokens_needed
                    .min(priority_budget)
                    .max(self.config.min_tokens_per_request);
                if tokens_to_allocate > 0 {
                    scheduled.push((request_id, tokens_to_allocate));
                    scheduled_from_queue.push(request_id);
                    priority_budget = priority_budget.saturating_sub(tokens_to_allocate);
                    remaining_budget = remaining_budget.saturating_sub(tokens_to_allocate);
                    remaining_slots -= 1;
                    self.stats.tokens_by_priority[queue_idx] += tokens_to_allocate as u64;
                }
            }
        }
        // Move admitted requests from the queue into `running`, stamping
        // TTFT the first time each is scheduled.
        for request_id in scheduled_from_queue {
            queue.retain(|&id| id != request_id);
            if let Some(request) = self.requests.get_mut(&request_id) {
                request.state = SequenceState::Running;
                self.running.push(request_id);
                if request.ttft_ms.is_none() {
                    let ttft = request.wait_time_ms() as f64;
                    request.ttft_ms = Some(ttft);
                    self.ttft_samples.push(ttft);
                }
            }
        }
    }
    self.update_queue_depths();
    scheduled
}
/// Finish a request: detach it from `running`, mark it completed, record
/// the SLA outcome against its deadline (if any), and refresh TTFT stats.
/// Returns the finished request, or `None` when the id is unknown.
pub fn complete_request(&mut self, request_id: u64) -> Option<DynamicRequest> {
    self.running.retain(|&id| id != request_id);
    let mut request = self.requests.remove(&request_id)?;
    request.state = SequenceState::Completed;
    self.stats.completed_requests += 1;
    // SLA bookkeeping: elapsed time versus the soft latency target.
    if let Some(deadline) = &request.deadline {
        if request.wait_time_ms() <= deadline.target_latency_ms {
            self.stats.sla_met += 1;
        } else {
            self.stats.sla_missed += 1;
        }
    }
    self.update_ttft_stats();
    Some(request)
}
/// Recompute average and p99 time-to-first-token from the collected
/// samples. Leaves the stats untouched while no samples exist.
fn update_ttft_stats(&mut self) {
    let count = self.ttft_samples.len();
    if count == 0 {
        return;
    }
    let total: f64 = self.ttft_samples.iter().sum();
    self.stats.avg_ttft_ms = total / count as f64;
    // Sort a copy; f64 is not Ord, so NaN pairs compare as Equal.
    let mut sorted = self.ttft_samples.clone();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    // Index of the 99th percentile, clamped to the last sample.
    let p99_idx = ((count as f64) * 0.99) as usize;
    self.stats.p99_ttft_ms = sorted.get(p99_idx.min(count - 1)).copied().unwrap_or(0.0);
}
/// Mirror each priority queue's current length into the stats snapshot.
fn update_queue_depths(&mut self) {
    for idx in 0..self.priority_queues.len() {
        self.stats.queue_depth_by_priority[idx] = self.priority_queues[idx].len();
    }
}
/// Look up a tracked (waiting or running) request by id.
pub fn get_request(&self, request_id: u64) -> Option<&DynamicRequest> {
    self.requests.get(&request_id)
}
/// Cumulative scheduling statistics.
pub fn stats(&self) -> &DynamicSchedulerStats {
    &self.stats
}
/// The active scheduler configuration.
pub fn config(&self) -> &DynamicPriorityConfig {
    &self.config
}
pub fn waiting_count(&self) -> usize {
self.priority_queues.iter().map(VecDeque::len).sum()
}
/// Number of requests currently scheduled (running).
pub fn running_count(&self) -> usize {
    self.running.len()
}
/// Fraction of completed deadline-bearing requests that met their target
/// latency. Reports 1.0 before any SLA outcome has been recorded.
pub fn sla_compliance_rate(&self) -> f64 {
    match self.stats.sla_met + self.stats.sla_missed {
        0 => 1.0,
        total => self.stats.sla_met as f64 / total as f64,
    }
}
/// Number of waiting requests at the given priority level.
pub fn queue_depth(&self, priority: Priority) -> usize {
    self.priority_queues[priority as usize].len()
}
}
/// Policy knobs for splitting long prompt prefills into smaller chunks so
/// that decode steps can be interleaved between them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkedPrefillConfig {
    // Master switch; when false every prompt bypasses chunking.
    pub enabled: bool,
    // Tokens processed per prefill chunk.
    pub chunk_size: usize,
    // Prompts shorter than this are prefilled in one shot.
    pub min_prompt_length: usize,
    // Allow decode steps between chunks of an ongoing prefill.
    pub allow_decode_interleave: bool,
    // Keep partially-prefilled sequences at the front of the queue instead
    // of rotating them round-robin after each chunk.
    pub boost_partial_prefill: bool,
    // Upper bound on chunks per prompt.
    // NOTE(review): not enforced anywhere in this file — confirm intended use.
    pub max_chunks: usize,
}
impl Default for ChunkedPrefillConfig {
    /// Balanced defaults: chunking enabled with 512-token chunks, prompts
    /// under 256 tokens bypass chunking, interleaving and boosting on.
    fn default() -> Self {
        Self {
            enabled: true,
            chunk_size: 512,
            min_prompt_length: 256,
            allow_decode_interleave: true,
            boost_partial_prefill: true,
            max_chunks: 16,
        }
    }
}
impl ChunkedPrefillConfig {
    /// Builder-style override of the chunk size.
    pub fn with_chunk_size(mut self, size: usize) -> Self {
        self.chunk_size = size;
        self
    }
    /// Preset with chunking turned off entirely; all other fields keep
    /// their defaults.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Default::default()
        }
    }
    /// Preset tuned for latency: small (128-token) chunks, low bypass
    /// threshold, decode interleaving and partial-prefill boosting on.
    pub fn low_latency() -> Self {
        Self {
            enabled: true,
            chunk_size: 128,
            min_prompt_length: 64,
            allow_decode_interleave: true,
            boost_partial_prefill: true,
            max_chunks: 32,
        }
    }
    /// Preset tuned for throughput: large (1024-token) chunks, no decode
    /// interleaving, round-robin rotation between partial prefills.
    pub fn high_throughput() -> Self {
        Self {
            enabled: true,
            chunk_size: 1024,
            min_prompt_length: 512,
            allow_decode_interleave: false,
            boost_partial_prefill: false,
            max_chunks: 8,
        }
    }
}
/// Progress tracker for one chunked prompt prefill.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkedPrefillState {
    // Sequence this state belongs to.
    pub seq_id: u64,
    // Total prompt tokens to prefill.
    pub total_tokens: usize,
    // Tokens prefilled so far.
    pub processed_tokens: usize,
    // Count of chunks completed so far (index of the next chunk).
    pub current_chunk: usize,
    // ceil(total_tokens / chunk_size), fixed at construction.
    pub total_chunks: usize,
    // Start timestamp in ms. NOTE(review): initialized to 0 and never
    // written elsewhere in this file — confirm callers populate it.
    pub start_time_ms: u64,
    // Per-chunk latency samples, in ms.
    pub chunk_latencies: Vec<u64>,
}
impl ChunkedPrefillState {
    /// Tracking state for a prompt of `total_tokens`, split into
    /// `ceil(total_tokens / chunk_size)` chunks.
    pub fn new(seq_id: u64, total_tokens: usize, chunk_size: usize) -> Self {
        let chunk_count = total_tokens.div_ceil(chunk_size);
        Self {
            seq_id,
            total_tokens,
            processed_tokens: 0,
            current_chunk: 0,
            total_chunks: chunk_count,
            start_time_ms: 0,
            chunk_latencies: Vec::with_capacity(chunk_count),
        }
    }
    /// Token range the next chunk should cover, clamped to the prompt end.
    pub fn next_chunk(&self, chunk_size: usize) -> std::ops::Range<usize> {
        let begin = self.processed_tokens;
        begin..self.total_tokens.min(begin + chunk_size)
    }
    /// Record a finished chunk: bump progress and store its latency.
    pub fn advance(&mut self, tokens_processed: usize, latency_ms: u64) {
        self.processed_tokens += tokens_processed;
        self.current_chunk += 1;
        self.chunk_latencies.push(latency_ms);
    }
    /// True once every prompt token has been prefilled.
    pub fn is_complete(&self) -> bool {
        self.processed_tokens >= self.total_tokens
    }
    /// Prefill progress as a percentage; an empty prompt counts as done.
    pub fn progress(&self) -> f64 {
        match self.total_tokens {
            0 => 100.0,
            total => self.processed_tokens as f64 / total as f64 * 100.0,
        }
    }
    /// Prompt tokens still awaiting prefill.
    pub fn remaining_tokens(&self) -> usize {
        self.total_tokens.saturating_sub(self.processed_tokens)
    }
    /// Mean per-chunk latency in ms; 0.0 before any chunk completes.
    pub fn avg_chunk_latency_ms(&self) -> f64 {
        let samples = self.chunk_latencies.len();
        if samples == 0 {
            return 0.0;
        }
        self.chunk_latencies.iter().sum::<u64>() as f64 / samples as f64
    }
}
/// Counters describing chunked-prefill activity.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChunkedPrefillStats {
    // Sequences that were split into chunks.
    pub chunked_sequences: u64,
    // Sequences prefilled in one shot (chunking off or prompt too short).
    pub bypassed_sequences: u64,
    // Total chunks executed.
    pub chunks_processed: u64,
    // Decode steps interleaved between prefill chunks.
    pub decode_interleaves: u64,
    // Sum of per-chunk latencies, in ms.
    pub total_chunk_latency_ms: u64,
    // Largest single-chunk latency observed, in ms.
    pub max_chunk_latency_ms: u64,
    // NOTE(review): despite the name, this accumulates *tokens saved* by
    // prefix-cache hits (see `record_prefix_cache_hit`), not hit events.
    pub prefix_cache_hits: u64,
}
impl ChunkedPrefillStats {
    /// Mean latency per processed chunk; 0.0 while none have run.
    pub fn avg_chunk_latency_ms(&self) -> f64 {
        match self.chunks_processed {
            0 => 0.0,
            n => self.total_chunk_latency_ms as f64 / n as f64,
        }
    }
    /// Fraction of submitted sequences that were chunked; 0.0 before any
    /// submission.
    pub fn chunking_rate(&self) -> f64 {
        match self.chunked_sequences + self.bypassed_sequences {
            0 => 0.0,
            total => self.chunked_sequences as f64 / total as f64,
        }
    }
}
/// Scheduler that splits long prompt prefills into chunks, orders them in a
/// FIFO queue, and optionally interleaves decode steps between chunks.
pub struct ChunkedPrefillScheduler {
    // Chunking policy.
    config: ChunkedPrefillConfig,
    // Per-sequence prefill progress, keyed by sequence id.
    active_prefills: HashMap<u64, ChunkedPrefillState>,
    // FIFO of sequence ids with outstanding prefill work.
    prefill_queue: VecDeque<u64>,
    // Cumulative counters.
    stats: ChunkedPrefillStats,
    // Monotonic id source used by `submit`.
    next_seq_id: u64,
}
impl ChunkedPrefillScheduler {
    /// Create a scheduler with the given chunking policy.
    pub fn new(config: ChunkedPrefillConfig) -> Self {
        Self {
            config,
            active_prefills: HashMap::new(),
            prefill_queue: VecDeque::new(),
            stats: ChunkedPrefillStats::default(),
            next_seq_id: 0,
        }
    }
    /// Remove `seq_id` from `queue` if present; returns whether it was
    /// found. Shared helper — this linear-scan removal was previously
    /// duplicated in `complete_chunk` (twice) and `remove`.
    fn remove_from_queue(queue: &mut VecDeque<u64>, seq_id: u64) -> bool {
        if let Some(pos) = queue.iter().position(|&id| id == seq_id) {
            queue.remove(pos);
            true
        } else {
            false
        }
    }
    /// Register a new prompt. Returns the assigned sequence id and whether
    /// the prompt will be prefilled in chunks (only when chunking is
    /// enabled and the prompt is at least `min_prompt_length` tokens).
    pub fn submit(&mut self, prompt_tokens: usize) -> (u64, bool) {
        let seq_id = self.next_seq_id;
        self.next_seq_id += 1;
        let use_chunking = self.config.enabled && prompt_tokens >= self.config.min_prompt_length;
        if use_chunking {
            let state = ChunkedPrefillState::new(seq_id, prompt_tokens, self.config.chunk_size);
            self.active_prefills.insert(seq_id, state);
            self.prefill_queue.push_back(seq_id);
            self.stats.chunked_sequences += 1;
        } else {
            self.stats.bypassed_sequences += 1;
        }
        (seq_id, use_chunking)
    }
    /// Next token range to prefill, taken from the front of the queue.
    /// Stale or already-complete entries are discarded as encountered.
    /// Returns `None` when no prefill work remains.
    pub fn next_chunk(&mut self) -> Option<(u64, std::ops::Range<usize>)> {
        while let Some(&seq_id) = self.prefill_queue.front() {
            if let Some(state) = self.active_prefills.get(&seq_id) {
                if !state.is_complete() {
                    return Some((seq_id, state.next_chunk(self.config.chunk_size)));
                }
            }
            // Front entry has no live, unfinished state: drop it and rescan.
            self.prefill_queue.pop_front();
        }
        None
    }
    /// Record the completion of one chunk: advance the sequence's progress,
    /// update latency stats, and requeue the sequence according to policy.
    /// Unknown sequence ids are ignored.
    pub fn complete_chunk(&mut self, seq_id: u64, tokens_processed: usize, latency_ms: u64) {
        let Some(state) = self.active_prefills.get_mut(&seq_id) else {
            return;
        };
        state.advance(tokens_processed, latency_ms);
        self.stats.chunks_processed += 1;
        self.stats.total_chunk_latency_ms += latency_ms;
        self.stats.max_chunk_latency_ms = self.stats.max_chunk_latency_ms.max(latency_ms);
        if state.is_complete() {
            // Fully prefilled: no more chunks to schedule for this sequence.
            Self::remove_from_queue(&mut self.prefill_queue, seq_id);
        } else if !self.config.boost_partial_prefill {
            // Round-robin: rotate the sequence to the back of the queue.
            // (With `boost_partial_prefill` it stays put so partial
            // prefills finish sooner.)
            if Self::remove_from_queue(&mut self.prefill_queue, seq_id) {
                self.prefill_queue.push_back(seq_id);
            }
        }
    }
    /// Count one decode step executed between prefill chunks.
    pub fn record_decode_interleave(&mut self) {
        self.stats.decode_interleaves += 1;
    }
    /// Whether a decode step may be interleaved now: policy allows it and
    /// prefill work is still queued.
    pub fn should_interleave_decode(&self) -> bool {
        self.config.allow_decode_interleave && !self.prefill_queue.is_empty()
    }
    /// Chunking progress for a sequence, if it is being tracked.
    pub fn get_state(&self, seq_id: u64) -> Option<&ChunkedPrefillState> {
        self.active_prefills.get(&seq_id)
    }
    /// True while the sequence still has unprefilled tokens.
    pub fn has_pending_prefill(&self, seq_id: u64) -> bool {
        self.active_prefills
            .get(&seq_id)
            .is_some_and(|s| !s.is_complete())
    }
    /// Drop a sequence entirely (queue entry and tracking state). Returns
    /// its state if it was being tracked.
    pub fn remove(&mut self, seq_id: u64) -> Option<ChunkedPrefillState> {
        Self::remove_from_queue(&mut self.prefill_queue, seq_id);
        self.active_prefills.remove(&seq_id)
    }
    /// Number of tracked sequences whose prefill is not yet finished.
    pub fn pending_count(&self) -> usize {
        self.active_prefills
            .values()
            .filter(|s| !s.is_complete())
            .count()
    }
    /// Current length of the prefill queue (may include stale entries that
    /// `next_chunk` has not yet discarded).
    pub fn queue_len(&self) -> usize {
        self.prefill_queue.len()
    }
    /// Cumulative chunking statistics.
    pub fn stats(&self) -> &ChunkedPrefillStats {
        &self.stats
    }
    /// The active chunking policy.
    pub fn config(&self) -> &ChunkedPrefillConfig {
        &self.config
    }
    /// Forget all tracked sequences and queued work. Statistics and the
    /// sequence-id counter are intentionally left untouched.
    pub fn clear(&mut self) {
        self.active_prefills.clear();
        self.prefill_queue.clear();
    }
    /// Credit a prefix-cache hit. NOTE(review): despite the field name,
    /// `prefix_cache_hits` accumulates *tokens saved*, not hit events.
    pub fn record_prefix_cache_hit(&mut self, tokens_saved: usize) {
        self.stats.prefix_cache_hits += tokens_saved as u64;
    }
}
impl Default for ChunkedPrefillScheduler {
    /// Scheduler using the default chunking policy.
    fn default() -> Self {
        Self::new(ChunkedPrefillConfig::default())
    }
}
#[cfg(all(test, feature = "heavy-tests"))]
mod tests {
use super::*;
#[test]
fn test_priority_ordering() {
assert!(Priority::Critical > Priority::High);
assert!(Priority::High > Priority::Normal);
assert!(Priority::Normal > Priority::Low);
}
#[test]
fn test_priority_default() {
assert_eq!(Priority::default(), Priority::Normal);
}
#[test]
fn test_request_new() {
let request = SchedulerRequest::new(1, vec![1, 2, 3], 10);
assert_eq!(request.request_id, 1);
assert_eq!(request.input_ids.len(), 3);
assert_eq!(request.max_tokens, 10);
assert_eq!(request.state, SequenceState::Waiting);
}
#[test]
fn test_request_with_priority() {
let request = SchedulerRequest::new(1, vec![1], 10).with_priority(Priority::High);
assert_eq!(request.priority, Priority::High);
}
#[test]
fn test_request_total_tokens() {
let mut request = SchedulerRequest::new(1, vec![1, 2, 3], 10);
assert_eq!(request.total_tokens(), 3);
request.generated_tokens = vec![4, 5];
assert_eq!(request.total_tokens(), 5);
}
#[test]
fn test_request_remaining_tokens() {
let mut request = SchedulerRequest::new(1, vec![1, 2, 3], 10);
assert_eq!(request.remaining_tokens(), 10);
request.generated_tokens = vec![4, 5, 6];
assert_eq!(request.remaining_tokens(), 7);
}
#[test]
fn test_request_is_complete() {
let mut request = SchedulerRequest::new(1, vec![1], 3);
assert!(!request.is_complete(0));
request.generated_tokens = vec![2, 3, 4];
assert!(request.is_complete(0));
let mut request2 = SchedulerRequest::new(2, vec![1], 10);
request2.generated_tokens = vec![2, 0]; assert!(request2.is_complete(0));
}
#[test]
fn test_scheduler_output_total_tokens() {
let output = SchedulerOutput {
num_prefill_tokens: 100,
num_decode_tokens: 10,
..Default::default()
};
assert_eq!(output.total_tokens(), 110);
}
#[test]
fn test_scheduler_output_is_empty() {
let output = SchedulerOutput::default();
assert!(output.is_empty());
}
#[test]
fn test_scheduler_new() {
let scheduler = Scheduler::new(32, 1000);
assert_eq!(scheduler.max_batch_size, 32);
assert_eq!(scheduler.max_queue_size, 1000);
assert_eq!(scheduler.waiting_count(), 0);
}
#[test]
fn test_scheduler_add_request() {
let mut scheduler = Scheduler::new(32, 1000);
let request_id = scheduler.add_request(vec![1, 2, 3], 10).unwrap();
assert_eq!(request_id, 0);
assert_eq!(scheduler.waiting_count(), 1);
assert_eq!(scheduler.stats().total_requests, 1);
}
#[test]
fn test_scheduler_add_request_queue_full() {
let mut scheduler = Scheduler::new(32, 1);
let _ = scheduler.add_request(vec![1], 10).unwrap();
let result = scheduler.add_request(vec![2], 10);
assert!(matches!(result, Err(SchedulerError::QueueFull { .. })));
}
#[test]
fn test_scheduler_schedule() {
let mut scheduler = Scheduler::new(32, 1000);
let mut kv_cache = PagedKvCache::new(100, 16, 8, 64);
let _ = scheduler.add_request(vec![1, 2, 3], 10).unwrap();
let output = scheduler.schedule(&mut kv_cache, 0).unwrap();
assert_eq!(output.scheduled_request_ids.len(), 1);
assert_eq!(scheduler.running_count(), 1);
assert_eq!(scheduler.waiting_count(), 0);
}
#[test]
fn test_scheduler_update_after_iteration() {
let mut scheduler = Scheduler::new(32, 1000);
let mut kv_cache = PagedKvCache::new(100, 16, 8, 64);
let request_id = scheduler.add_request(vec![1], 10).unwrap();
let _ = scheduler.schedule(&mut kv_cache, 0).unwrap();
let mut generated = HashMap::new();
generated.insert(request_id, 42u32);
scheduler.update_after_iteration(&generated);
let request = scheduler.get_request(request_id).unwrap();
assert_eq!(request.generated_tokens, vec![42]);
assert_eq!(request.iterations, 1);
}
#[test]
fn test_scheduler_complete_request() {
let mut scheduler = Scheduler::new(32, 1000);
let mut kv_cache = PagedKvCache::new(100, 16, 8, 64);
let request_id = scheduler.add_request(vec![1], 10).unwrap();
let _ = scheduler.schedule(&mut kv_cache, 0).unwrap();
scheduler.complete_request(request_id, &mut kv_cache);
assert_eq!(scheduler.running_count(), 0);
assert_eq!(scheduler.stats().completed_requests, 1);
}
#[test]
fn test_scheduler_priority_ordering() {
let mut scheduler = Scheduler::new(1, 1000); let mut kv_cache = PagedKvCache::new(100, 16, 8, 64);
let low_id = scheduler
.add_request_with_priority(vec![1], 10, Priority::Low)
.unwrap();
let _high_id = scheduler
.add_request_with_priority(vec![2], 10, Priority::High)
.unwrap();
let output = scheduler.schedule(&mut kv_cache, 0).unwrap();
assert_eq!(output.scheduled_request_ids.len(), 1);
let low_request = scheduler.get_request(low_id).unwrap();
assert!(
low_request.state == SequenceState::Waiting
|| low_request.state == SequenceState::Preempted
);
}
#[test]
fn test_scheduler_max_batch_size() {
let mut scheduler = Scheduler::new(2, 1000);
let mut kv_cache = PagedKvCache::new(100, 16, 8, 64);
let _ = scheduler.add_request(vec![1], 10).unwrap();
let _ = scheduler.add_request(vec![2], 10).unwrap();
let _ = scheduler.add_request(vec![3], 10).unwrap();
let output = scheduler.schedule(&mut kv_cache, 0).unwrap();
assert_eq!(output.scheduled_request_ids.len(), 2);
assert_eq!(scheduler.running_count(), 2);
assert_eq!(scheduler.waiting_count(), 1);
}
#[test]
fn test_scheduler_stats() {
let mut scheduler = Scheduler::new(32, 1000);
let mut kv_cache = PagedKvCache::new(100, 16, 8, 64);
let request_id = scheduler.add_request(vec![1], 10).unwrap();
let _ = scheduler.schedule(&mut kv_cache, 0).unwrap();
scheduler.complete_request(request_id, &mut kv_cache);
let stats = scheduler.stats();
assert_eq!(stats.total_requests, 1);
assert_eq!(stats.completed_requests, 1);
}
#[test]
fn test_scheduler_error_display() {
let err = SchedulerError::QueueFull { capacity: 100 };
assert!(err.to_string().contains("100"));
let err = SchedulerError::RequestNotFound(42);
assert!(err.to_string().contains("42"));
let err = SchedulerError::InvalidState("test".to_string());
assert!(err.to_string().contains("test"));
}
#[test]
fn test_sequence_state_variants() {
let states = [
SequenceState::Waiting,
SequenceState::Running,
SequenceState::Preempted,
SequenceState::Completed,
SequenceState::Failed,
];
for (i, s1) in states.iter().enumerate() {
for (j, s2) in states.iter().enumerate() {
if i == j {
assert_eq!(s1, s2);
} else {
assert_ne!(s1, s2);
}
}
}
}
#[test]
fn test_scheduler_stats_serialization() {
let stats = SchedulerStats {
total_requests: 100,
completed_requests: 90,
preemptions: 5,
avg_wait_time_ms: 10.5,
avg_ttft_ms: 50.0,
queue_depth: 10,
running_count: 8,
};
let json = serde_json::to_string(&stats).unwrap();
let parsed: SchedulerStats = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.total_requests, stats.total_requests);
assert_eq!(parsed.preemptions, stats.preemptions);
}
#[test]
fn test_slot_state_default() {
assert_eq!(SlotState::default(), SlotState::Idle);
}
#[test]
fn test_slot_new() {
let slot = Slot::new(0);
assert_eq!(slot.id, 0);
assert!(slot.is_idle());
assert!(!slot.is_generating());
assert!(slot.request_id.is_none());
}
#[test]
fn test_slot_assign() {
let mut slot = Slot::new(0);
slot.assign(42, vec![1, 2, 3], 10);
assert_eq!(slot.state, SlotState::Processing);
assert_eq!(slot.request_id, Some(42));
assert_eq!(slot.input_tokens, vec![1, 2, 3]);
assert_eq!(slot.max_tokens, 10);
assert!(!slot.is_idle());
}
#[test]
fn test_slot_start_generation() {
let mut slot = Slot::new(0);
slot.assign(1, vec![1], 10);
slot.start_generation(5.0);
assert_eq!(slot.state, SlotState::Generating);
assert!(slot.is_generating());
assert_eq!(slot.prompt_time_ms, 5.0);
assert!(slot.generation_start.is_some());
}
#[test]
fn test_slot_add_token() {
let mut slot = Slot::new(0);
slot.assign(1, vec![1], 10);
slot.start_generation(1.0);
slot.add_token(100);
slot.add_token(200);
assert_eq!(slot.generated_tokens, vec![100, 200]);
}
#[test]
fn test_slot_is_complete_max_tokens() {
let mut slot = Slot::new(0);
slot.assign(1, vec![1], 3);
slot.start_generation(1.0);
slot.add_token(100);
assert!(!slot.is_complete(999));
slot.add_token(200);
slot.add_token(300);
assert!(slot.is_complete(999)); }
#[test]
fn test_slot_is_complete_eos() {
let mut slot = Slot::new(0);
slot.assign(1, vec![1], 100);
slot.start_generation(1.0);
slot.add_token(100);
assert!(!slot.is_complete(999));
slot.add_token(999); assert!(slot.is_complete(999));
}
#[test]
fn test_slot_finish() {
let mut slot = Slot::new(0);
slot.assign(1, vec![1], 10);
slot.start_generation(1.0);
slot.add_token(100);
slot.finish();
assert!(slot.is_idle());
assert!(slot.request_id.is_none());
assert!(slot.seq_id.is_none());
}
#[test]
fn test_slot_manager_new() {
let manager = SlotManager::new(4, 2048);
assert_eq!(manager.num_slots(), 4);
assert_eq!(manager.num_idle_slots(), 4);
assert_eq!(manager.num_active_slots(), 0);
assert_eq!(manager.max_context_length, 2048);
}
#[test]
fn test_slot_manager_assign_request() {
let mut manager = SlotManager::new(4, 2048);
let result = manager.assign_request(vec![1, 2, 3], 10);
assert!(result.is_some());
let (slot_id, request_id) = result.unwrap();
assert_eq!(slot_id, 0);
assert_eq!(request_id, 0);
assert_eq!(manager.num_idle_slots(), 3);
assert_eq!(manager.num_active_slots(), 1);
}
#[test]
fn test_slot_manager_no_slots_available() {
let mut manager = SlotManager::new(2, 2048);
manager.assign_request(vec![1], 10).unwrap();
manager.assign_request(vec![2], 10).unwrap();
let result = manager.assign_request(vec![3], 10);
assert!(result.is_none());
}
#[test]
fn test_slot_manager_utilization() {
let mut manager = SlotManager::new(4, 2048);
assert_eq!(manager.utilization(), 0.0);
manager.assign_request(vec![1], 10);
assert!((manager.utilization() - 0.25).abs() < 0.01);
manager.assign_request(vec![2], 10);
assert!((manager.utilization() - 0.5).abs() < 0.01);
}
#[test]
fn test_slot_manager_batch_slots() {
let mut manager = SlotManager::new(4, 2048);
manager.assign_request(vec![1], 10);
manager.assign_request(vec![2], 10);
manager.get_slot_mut(0).unwrap().start_generation(1.0);
let batch = manager.batch_slots();
assert_eq!(batch, vec![0]); }
#[test]
fn test_slot_manager_get_slot() {
let manager = SlotManager::new(4, 2048);
assert!(manager.get_slot(0).is_some());
assert!(manager.get_slot(3).is_some());
assert!(manager.get_slot(4).is_none()); }
#[test]
fn test_slot_manager_active_slots() {
let mut manager = SlotManager::new(4, 2048);
manager.assign_request(vec![1], 10);
manager.assign_request(vec![2], 10);
let active: Vec<_> = manager.active_slots().collect();
assert_eq!(active.len(), 2);
}
#[test]
fn test_batch_type_default() {
assert_eq!(BatchType::default(), BatchType::Decode);
}
#[test]
fn test_batch_token_new() {
let token = BatchToken::new(42, 0, 5, true);
assert_eq!(token.token_id, 42);
assert_eq!(token.seq_idx, 0);
assert_eq!(token.seq_pos, 5);
assert!(token.is_prompt);
}
#[test]
fn test_micro_batch_new() {
let batch = MicroBatch::new();
assert!(batch.is_empty());
assert_eq!(batch.len(), 0);
assert_eq!(batch.num_sequences(), 0);
assert!(batch.is_decode()); }
#[test]
fn test_micro_batch_add_tokens() {
let mut batch = MicroBatch::new();
batch.add_token(BatchToken::new(1, 0, 0, true));
batch.add_token(BatchToken::new(2, 0, 1, true));
assert_eq!(batch.len(), 2);
assert_eq!(batch.num_sequences(), 1);
assert!(batch.is_prefill());
assert_eq!(batch.n_prompt_tokens, 2);
assert_eq!(batch.n_decode_tokens, 0);
}
#[test]
fn test_micro_batch_mixed_type() {
let mut batch = MicroBatch::new();
batch.add_token(BatchToken::new(1, 0, 0, true)); batch.add_token(BatchToken::new(2, 1, 5, false));
assert!(batch.is_mixed());
assert_eq!(batch.n_prompt_tokens, 1);
assert_eq!(batch.n_decode_tokens, 1);
}
#[test]
fn test_micro_batch_token_ids() {
let mut batch = MicroBatch::new();
batch.add_token(BatchToken::new(10, 0, 0, true));
batch.add_token(BatchToken::new(20, 0, 1, true));
batch.add_token(BatchToken::new(30, 0, 2, true));
assert_eq!(batch.token_ids(), vec![10, 20, 30]);
}
#[test]
fn test_micro_batch_positions() {
let mut batch = MicroBatch::new();
batch.add_token(BatchToken::new(10, 0, 0, true));
batch.add_token(BatchToken::new(20, 0, 1, true));
batch.add_token(BatchToken::new(30, 1, 5, false));
assert_eq!(batch.positions(), vec![0, 1, 5]);
}
#[test]
fn test_micro_batch_clear() {
let mut batch = MicroBatch::new();
batch.add_token(BatchToken::new(1, 0, 0, true));
batch.add_token(BatchToken::new(2, 1, 0, false));
batch.clear();
assert!(batch.is_empty());
assert_eq!(batch.n_prompt_tokens, 0);
assert_eq!(batch.n_decode_tokens, 0);
assert_eq!(batch.max_seq_len, 0);
}
#[test]
fn test_micro_batch_max_seq_len() {
let mut batch = MicroBatch::new();
batch.add_token(BatchToken::new(1, 0, 0, true));
batch.add_token(BatchToken::new(2, 0, 10, true));
assert_eq!(batch.max_seq_len, 11); }
#[test]
fn test_sequence_batch_entry_new() {
let entry = SequenceBatchEntry::new(0, 1, 100);
assert_eq!(entry.seq_idx, 0);
assert_eq!(entry.slot_id, 1);
assert_eq!(entry.request_id, 100);
assert!(entry.is_prefill);
assert_eq!(entry.position, 0);
}
#[test]
fn test_sequence_batch_entry_builder() {
let entry = SequenceBatchEntry::new(0, 1, 100)
.with_tokens(vec![1, 2, 3])
.at_position(5)
.decoding();
assert_eq!(entry.tokens, vec![1, 2, 3]);
assert_eq!(entry.position, 5);
assert!(!entry.is_prefill);
}
#[test]
fn test_sequence_batch_new() {
let batch = SequenceBatch::new(8);
assert!(batch.is_empty());
assert!(!batch.is_full());
assert_eq!(batch.max_batch_size, 8);
}
#[test]
fn test_sequence_batch_add_remove() {
let mut batch = SequenceBatch::new(4);
let entry = SequenceBatchEntry::new(0, 0, 1);
assert!(batch.add_sequence(entry));
assert_eq!(batch.len(), 1);
let removed = batch.remove_sequence(0);
assert!(removed.is_some());
assert!(batch.is_empty());
}
#[test]
fn test_sequence_batch_full() {
let mut batch = SequenceBatch::new(2);
batch.add_sequence(SequenceBatchEntry::new(0, 0, 1));
batch.add_sequence(SequenceBatchEntry::new(1, 1, 2));
assert!(batch.is_full());
assert!(!batch.add_sequence(SequenceBatchEntry::new(2, 2, 3)));
}
#[test]
fn test_sequence_batch_prefill_decode_counts() {
let mut batch = SequenceBatch::new(4);
batch.add_sequence(SequenceBatchEntry::new(0, 0, 1)); batch.add_sequence(SequenceBatchEntry::new(1, 1, 2).decoding()); batch.add_sequence(SequenceBatchEntry::new(2, 2, 3).decoding());
assert_eq!(batch.num_prefill(), 1);
assert_eq!(batch.num_decode(), 2);
}
#[test]
fn test_sequence_batch_utilization() {
let mut batch = SequenceBatch::new(4);
assert_eq!(batch.utilization, 0.0);
batch.add_sequence(SequenceBatchEntry::new(0, 0, 1));
assert!((batch.utilization - 0.25).abs() < 0.01);
batch.add_sequence(SequenceBatchEntry::new(1, 1, 2));
assert!((batch.utilization - 0.5).abs() < 0.01);
}
#[test]
fn test_batch_config_default() {
let config = BatchConfig::default();
assert_eq!(config.max_ubatch_tokens, 512);
assert_eq!(config.max_sbatch_sequences, 8);
assert!(config.prefer_pure_decode);
assert!(config.dynamic_batching);
}
#[test]
fn test_batch_config_builder() {
let config = BatchConfig::default()
.with_max_tokens(1024)
.with_max_sequences(16);
assert_eq!(config.max_ubatch_tokens, 1024);
assert_eq!(config.max_sbatch_sequences, 16);
}
#[test]
fn test_batch_scheduler_new() {
let scheduler = BatchScheduler::new();
assert!(scheduler.has_capacity());
assert_eq!(scheduler.num_sequences(), 0);
assert_eq!(scheduler.utilization(), 0.0);
}
#[test]
fn test_batch_scheduler_add_sequence() {
let mut scheduler = BatchScheduler::new();
let seq_idx = scheduler.add_sequence(0, 1, vec![10, 20, 30]);
assert!(seq_idx.is_some());
assert_eq!(seq_idx.unwrap(), 0);
assert_eq!(scheduler.num_sequences(), 1);
}
#[test]
fn test_batch_scheduler_complete_sequence() {
let mut scheduler = BatchScheduler::new();
let seq_idx = scheduler.add_sequence(0, 1, vec![10, 20]).unwrap();
assert_eq!(scheduler.num_sequences(), 1);
let completed = scheduler.complete_sequence(seq_idx);
assert!(completed.is_some());
assert_eq!(scheduler.num_sequences(), 0);
}
#[test]
fn test_batch_scheduler_start_decode() {
let mut scheduler = BatchScheduler::new();
let seq_idx = scheduler.add_sequence(0, 1, vec![10, 20, 30]).unwrap();
assert!(scheduler.sbatch().get(seq_idx).unwrap().is_prefill);
assert!(scheduler.start_decode(seq_idx, 3));
let entry = scheduler.sbatch().get(seq_idx).unwrap();
assert!(!entry.is_prefill);
assert_eq!(entry.position, 3);
assert!(entry.tokens.is_empty()); }
#[test]
fn test_batch_scheduler_create_ubatch_prefill() {
let mut scheduler = BatchScheduler::new();
scheduler.add_sequence(0, 1, vec![10, 20, 30]);
let ubatch = scheduler.create_ubatch();
assert!(ubatch.is_prefill());
assert_eq!(ubatch.len(), 3);
assert_eq!(ubatch.token_ids(), vec![10, 20, 30]);
}
#[test]
fn test_batch_scheduler_create_ubatch_decode() {
let mut scheduler = BatchScheduler::new();
let seq_idx = scheduler.add_sequence(0, 1, vec![10, 20, 30]).unwrap();
scheduler.start_decode(seq_idx, 3);
let ubatch = scheduler.create_ubatch();
assert!(ubatch.is_decode());
assert_eq!(ubatch.len(), 1);
}
#[test]
fn test_batch_scheduler_stats() {
let mut scheduler = BatchScheduler::new();
scheduler.add_sequence(0, 1, vec![10, 20, 30]);
scheduler.create_ubatch();
let stats = scheduler.stats();
assert_eq!(stats.ubatches_created, 1);
assert_eq!(stats.tokens_processed, 3);
assert_eq!(stats.prefill_tokens, 3);
}
#[test]
fn test_batch_scheduler_capacity() {
let config = BatchConfig::default().with_max_sequences(2);
let mut scheduler = BatchScheduler::with_config(config);
scheduler.add_sequence(0, 1, vec![1]);
scheduler.add_sequence(1, 2, vec![2]);
assert!(!scheduler.has_capacity());
assert!(scheduler.add_sequence(2, 3, vec![3]).is_none());
}
#[test]
fn test_batch_stats_default() {
let stats = BatchStats::default();
assert_eq!(stats.ubatches_created, 0);
assert_eq!(stats.tokens_processed, 0);
assert_eq!(stats.avg_ubatch_size, 0.0);
}
#[test]
fn test_deadline_default() {
let deadline = Deadline::default();
assert_eq!(deadline.target_latency_ms, 1000);
assert!(deadline.hard_deadline_ms.is_none());
assert!((deadline.sla_target - 0.99).abs() < 0.001);
}
#[test]
fn test_deadline_with_target() {
let deadline = Deadline::with_target(500);
assert_eq!(deadline.target_latency_ms, 500);
}
#[test]
fn test_deadline_strict() {
let deadline = Deadline::strict(100, 200);
assert_eq!(deadline.target_latency_ms, 100);
assert_eq!(deadline.hard_deadline_ms, Some(200));
assert!((deadline.sla_target - 1.0).abs() < 0.001);
}
#[test]
fn test_dynamic_priority_config_default() {
let config = DynamicPriorityConfig::default();
assert!(config.enable_age_promotion);
assert_eq!(config.promotion_interval_ms, 5000);
assert_eq!(config.max_promoted_priority, Priority::High);
assert!(config.enable_deadline_scheduling);
assert!(config.enable_fair_share);
}
#[test]
fn test_dynamic_priority_config_builder() {
let config = DynamicPriorityConfig::with_budgets([0.1, 0.2, 0.3, 0.4])
.no_promotion()
.with_promotion_interval(1000);
assert!(!config.enable_age_promotion);
assert_eq!(config.promotion_interval_ms, 1000);
assert!((config.priority_budgets[0] - 0.1).abs() < 0.001);
}
#[test]
fn test_dynamic_request_new() {
let request = DynamicRequest::new(0, vec![1, 2, 3], 10);
assert_eq!(request.request_id, 0);
assert_eq!(request.input_ids.len(), 3);
assert_eq!(request.max_tokens, 10);
assert_eq!(request.original_priority, Priority::Normal);
assert_eq!(request.effective_priority, Priority::Normal);
assert_eq!(request.promotions, 0);
}
#[test]
fn test_dynamic_request_with_priority() {
let request = DynamicRequest::new(0, vec![1], 10).with_priority(Priority::High);
assert_eq!(request.original_priority, Priority::High);
assert_eq!(request.effective_priority, Priority::High);
}
#[test]
fn test_dynamic_request_with_deadline() {
let request = DynamicRequest::new(0, vec![1], 10).with_deadline(Deadline::with_target(500));
assert!(request.deadline.is_some());
assert_eq!(request.deadline.unwrap().target_latency_ms, 500);
}
#[test]
fn test_dynamic_request_urgency_no_deadline() {
let request = DynamicRequest::new(0, vec![1], 10);
assert_eq!(request.urgency_score(), 0.0);
assert!(!request.is_urgent());
}
#[test]
fn test_dynamic_request_remaining_tokens() {
let mut request = DynamicRequest::new(0, vec![1], 10);
assert_eq!(request.remaining_tokens(), 10);
request.generated_tokens = vec![2, 3, 4];
assert_eq!(request.remaining_tokens(), 7);
}
#[test]
fn test_dynamic_request_total_tokens() {
let mut request = DynamicRequest::new(0, vec![1, 2, 3], 10);
assert_eq!(request.total_tokens(), 3);
request.generated_tokens = vec![4, 5];
assert_eq!(request.total_tokens(), 5);
}
#[test]
fn test_dynamic_scheduler_new() {
let scheduler = DynamicPriorityScheduler::new(1024);
assert_eq!(scheduler.waiting_count(), 0);
assert_eq!(scheduler.running_count(), 0);
assert_eq!(scheduler.batch_token_budget, 1024);
}
#[test]
fn test_dynamic_scheduler_add_request() {
let mut scheduler = DynamicPriorityScheduler::new(1024);
let id1 = scheduler.add_request(vec![1, 2, 3], 10, Priority::Normal, None);
assert_eq!(id1, 0);
assert_eq!(scheduler.waiting_count(), 1);
assert_eq!(scheduler.queue_depth(Priority::Normal), 1);
let id2 = scheduler.add_request(vec![4, 5], 5, Priority::High, None);
assert_eq!(id2, 1);
assert_eq!(scheduler.waiting_count(), 2);
assert_eq!(scheduler.queue_depth(Priority::High), 1);
}
#[test]
fn test_dynamic_scheduler_add_simple_request() {
let mut scheduler = DynamicPriorityScheduler::new(1024);
let id = scheduler.add_simple_request(vec![1, 2], 5);
assert_eq!(id, 0);
assert_eq!(scheduler.queue_depth(Priority::Normal), 1);
}
#[test]
fn test_dynamic_scheduler_schedule_priority_order() {
let mut scheduler = DynamicPriorityScheduler::new(1024);
let low_id = scheduler.add_request(vec![1], 5, Priority::Low, None);
let normal_id = scheduler.add_request(vec![2], 5, Priority::Normal, None);
let high_id = scheduler.add_request(vec![3], 5, Priority::High, None);
let batch = scheduler.schedule(2);
assert_eq!(batch.len(), 2);
let scheduled_ids: Vec<_> = batch.iter().map(|(id, _)| *id).collect();
assert!(scheduled_ids.contains(&high_id));
assert!(scheduled_ids.contains(&normal_id));
assert!(!scheduled_ids.contains(&low_id));
}
#[test]
fn test_dynamic_scheduler_complete_request() {
let mut scheduler = DynamicPriorityScheduler::new(1024);
let id = scheduler.add_simple_request(vec![1], 5);
let _ = scheduler.schedule(1);
assert_eq!(scheduler.running_count(), 1);
let completed = scheduler.complete_request(id);
assert!(completed.is_some());
assert_eq!(scheduler.running_count(), 0);
assert_eq!(scheduler.stats().completed_requests, 1);
}
#[test]
fn test_dynamic_scheduler_sla_compliance() {
let mut scheduler = DynamicPriorityScheduler::new(1024);
let id = scheduler.add_request(
vec![1],
5,
Priority::Normal,
Some(Deadline::with_target(100_000)), );
let _ = scheduler.schedule(1);
let _ = scheduler.complete_request(id);
assert_eq!(scheduler.stats().sla_met, 1);
assert_eq!(scheduler.stats().sla_missed, 0);
assert!((scheduler.sla_compliance_rate() - 1.0).abs() < 0.001);
}
#[test]
fn test_dynamic_scheduler_stats() {
let scheduler = DynamicPriorityScheduler::new(1024);
let stats = scheduler.stats();
assert_eq!(stats.total_requests, 0);
assert_eq!(stats.completed_requests, 0);
assert_eq!(stats.promotions, 0);
assert_eq!(stats.dropped_requests, 0);
}
#[test]
fn test_dynamic_scheduler_stats_serialization() {
    // Stats survive a serde_json round-trip with counters intact.
    let original = DynamicSchedulerStats {
        total_requests: 100,
        completed_requests: 90,
        sla_met: 85,
        sla_missed: 5,
        dropped_requests: 10,
        promotions: 20,
        avg_ttft_ms: 50.5,
        p99_ttft_ms: 200.0,
        tokens_by_priority: [100, 500, 300, 100],
        queue_depth_by_priority: [5, 10, 3, 1],
    };
    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: DynamicSchedulerStats = serde_json::from_str(&encoded).unwrap();
    assert_eq!(decoded.total_requests, 100);
    assert_eq!(decoded.sla_met, 85);
}
#[test]
fn test_dynamic_scheduler_get_request() {
    // Lookup by id returns the stored request; unknown ids yield None.
    let mut sched = DynamicPriorityScheduler::new(1024);
    let req_id = sched.add_simple_request(vec![1, 2, 3], 10);
    match sched.get_request(req_id) {
        Some(req) => assert_eq!(req.input_ids, vec![1, 2, 3]),
        None => panic!("request should be retrievable by its id"),
    }
    assert!(sched.get_request(999).is_none());
}
#[test]
fn test_dynamic_scheduler_config() {
    // `no_promotion` switches off age-based priority promotion.
    let cfg = DynamicPriorityConfig::default().no_promotion();
    let sched = DynamicPriorityScheduler::with_config(1024, cfg);
    assert!(!sched.config().enable_age_promotion);
}
#[test]
fn test_dynamic_scheduler_queue_depths() {
    // Each queue tracks only its own priority class; totals add up.
    let mut sched = DynamicPriorityScheduler::new(1024);
    let submissions = [
        (vec![1u32], Priority::Low),
        (vec![2], Priority::Low),
        (vec![3], Priority::Normal),
        (vec![4], Priority::High),
        (vec![5], Priority::Critical),
    ];
    for (ids, prio) in submissions {
        sched.add_request(ids, 5, prio, None);
    }

    let expected = [
        (Priority::Low, 2),
        (Priority::Normal, 1),
        (Priority::High, 1),
        (Priority::Critical, 1),
    ];
    for (prio, depth) in expected {
        assert_eq!(sched.queue_depth(prio), depth);
    }
    assert_eq!(sched.waiting_count(), 5);
}
#[test]
fn test_dynamic_scheduler_token_budget_allocation() {
    // Under a tight token budget all four requests are still scheduled,
    // and the High/Critical classes receive a non-zero token share.
    let mut sched = DynamicPriorityScheduler::new(100);
    let submissions = [
        (vec![1u32], Priority::Low),
        (vec![2], Priority::Normal),
        (vec![3], Priority::High),
        (vec![4], Priority::Critical),
    ];
    for (ids, prio) in submissions {
        sched.add_request(ids, 50, prio, None);
    }

    let batch = sched.schedule(4);
    assert_eq!(batch.len(), 4);

    let stats = sched.stats();
    assert!(stats.tokens_by_priority[3] > 0);
    assert!(stats.tokens_by_priority[2] > 0);
}
#[test]
fn test_chunked_prefill_config_default() {
    // Defaults: chunking on, 512-token chunks, prompts under 256 bypass.
    let cfg = ChunkedPrefillConfig::default();
    assert!(cfg.enabled);
    assert_eq!(cfg.chunk_size, 512);
    assert_eq!(cfg.min_prompt_length, 256);
    assert!(cfg.allow_decode_interleave);
    assert!(cfg.boost_partial_prefill);
    assert_eq!(cfg.max_chunks, 16);
}
#[test]
fn test_chunked_prefill_config_disabled() {
    // The `disabled` preset switches chunking off entirely.
    assert!(!ChunkedPrefillConfig::disabled().enabled);
}
#[test]
fn test_chunked_prefill_config_low_latency() {
    // Low-latency preset: small chunks, low bypass threshold.
    let cfg = ChunkedPrefillConfig::low_latency();
    assert!(cfg.enabled);
    assert_eq!(cfg.chunk_size, 128);
    assert_eq!(cfg.min_prompt_length, 64);
}
#[test]
fn test_chunked_prefill_config_high_throughput() {
    // High-throughput preset: big chunks, no decode interleaving.
    let cfg = ChunkedPrefillConfig::high_throughput();
    assert!(cfg.enabled);
    assert_eq!(cfg.chunk_size, 1024);
    assert!(!cfg.allow_decode_interleave);
}
#[test]
fn test_chunked_prefill_config_with_chunk_size() {
    // Builder override replaces the default chunk size.
    let cfg = ChunkedPrefillConfig::default().with_chunk_size(256);
    assert_eq!(cfg.chunk_size, 256);
}
#[test]
fn test_chunked_prefill_state_new() {
    // 1000 tokens at chunk size 512 => 2 chunks, nothing processed yet.
    let st = ChunkedPrefillState::new(1, 1000, 512);
    assert_eq!(st.seq_id, 1);
    assert_eq!(st.total_tokens, 1000);
    assert_eq!(st.processed_tokens, 0);
    assert_eq!(st.current_chunk, 0);
    assert_eq!(st.total_chunks, 2);
    assert!(!st.is_complete());
}
#[test]
fn test_chunked_prefill_state_next_chunk() {
    // The first chunk covers the opening window of the prompt.
    let st = ChunkedPrefillState::new(1, 1000, 512);
    assert_eq!(st.next_chunk(512), 0..512);
}
#[test]
fn test_chunked_prefill_state_advance() {
    // Advancing records progress and latency, and shifts the next window.
    let mut st = ChunkedPrefillState::new(1, 1000, 512);
    st.advance(512, 50);
    assert_eq!(st.processed_tokens, 512);
    assert_eq!(st.current_chunk, 1);
    assert_eq!(st.chunk_latencies.len(), 1);
    assert_eq!(st.chunk_latencies[0], 50);
    // The remaining tail is shorter than a full chunk.
    assert_eq!(st.next_chunk(512), 512..1000);
}
#[test]
fn test_chunked_prefill_state_completion() {
    // Completion flips only once every token has been processed.
    let mut st = ChunkedPrefillState::new(1, 1000, 512);
    assert!(!st.is_complete());
    st.advance(512, 50);
    assert!(!st.is_complete());
    st.advance(488, 40);
    assert!(st.is_complete());
}
#[test]
fn test_chunked_prefill_state_progress() {
    // Progress is reported as a percentage of total tokens.
    let mut st = ChunkedPrefillState::new(1, 1000, 512);
    assert!((st.progress() - 0.0).abs() < 0.01);
    st.advance(500, 50);
    assert!((st.progress() - 50.0).abs() < 0.01);
    st.advance(500, 50);
    assert!((st.progress() - 100.0).abs() < 0.01);
}
#[test]
fn test_chunked_prefill_state_remaining_tokens() {
    // Remaining tokens shrink by exactly the amount advanced.
    let mut st = ChunkedPrefillState::new(1, 1000, 512);
    assert_eq!(st.remaining_tokens(), 1000);
    st.advance(600, 50);
    assert_eq!(st.remaining_tokens(), 400);
}
#[test]
fn test_chunked_prefill_state_avg_latency() {
    // Average chunk latency over recorded completions: 50 alone, then
    // (50 + 30) / 2 = 40.
    // Use tolerance comparisons instead of exact float equality
    // (clippy::float_cmp), matching the epsilon style used elsewhere
    // in this suite.
    let mut state = ChunkedPrefillState::new(1, 1000, 512);
    assert!(state.avg_chunk_latency_ms().abs() < 1e-9);
    state.advance(512, 50);
    assert!((state.avg_chunk_latency_ms() - 50.0).abs() < 1e-9);
    state.advance(488, 30);
    assert!((state.avg_chunk_latency_ms() - 40.0).abs() < 1e-9);
}
#[test]
fn test_chunked_prefill_state_zero_tokens() {
    // A zero-length prompt is trivially complete and 100% progressed.
    let state = ChunkedPrefillState::new(1, 0, 512);
    assert!(state.is_complete());
    assert!((state.progress() - 100.0).abs() < 1e-9);
}
#[test]
fn test_chunked_prefill_stats_default() {
    // Default stats start at zero across all counters and rates.
    let stats = ChunkedPrefillStats::default();
    assert_eq!(stats.chunked_sequences, 0);
    assert_eq!(stats.bypassed_sequences, 0);
    assert_eq!(stats.chunks_processed, 0);
    // Tolerance compare instead of exact float equality (clippy::float_cmp),
    // consistent with the epsilon comparisons used elsewhere in this suite.
    assert!(stats.avg_chunk_latency_ms().abs() < 1e-9);
    assert!(stats.chunking_rate().abs() < 1e-9);
}
#[test]
fn test_chunked_prefill_stats_avg_latency() {
    // 200 ms over 4 chunks => 50 ms average.
    let stats = ChunkedPrefillStats {
        chunks_processed: 4,
        total_chunk_latency_ms: 200,
        ..Default::default()
    };
    assert!((stats.avg_chunk_latency_ms() - 50.0).abs() < 1e-9);
}
#[test]
fn test_chunked_prefill_stats_chunking_rate() {
    // 3 chunked of 10 total sequences => 30% chunking rate.
    let stats = ChunkedPrefillStats {
        chunked_sequences: 3,
        bypassed_sequences: 7,
        ..Default::default()
    };
    assert!((stats.chunking_rate() - 0.3).abs() < 0.01);
}
#[test]
fn test_chunked_prefill_scheduler_new() {
    // A new scheduler starts with empty queues.
    let sched = ChunkedPrefillScheduler::new(ChunkedPrefillConfig::default());
    assert_eq!(sched.queue_len(), 0);
    assert_eq!(sched.pending_count(), 0);
}
#[test]
fn test_chunked_prefill_scheduler_submit_short() {
    // Prompts below the minimum length bypass chunking entirely.
    let mut sched = ChunkedPrefillScheduler::new(ChunkedPrefillConfig::default());
    let (seq_id, chunked) = sched.submit(100);
    assert_eq!(seq_id, 0);
    assert!(!chunked);
    assert_eq!(sched.stats().bypassed_sequences, 1);
    assert_eq!(sched.stats().chunked_sequences, 0);
}
#[test]
fn test_chunked_prefill_scheduler_submit_long() {
    // Long prompts are queued for chunked prefill.
    let mut sched = ChunkedPrefillScheduler::new(ChunkedPrefillConfig::default());
    let (seq_id, chunked) = sched.submit(1000);
    assert_eq!(seq_id, 0);
    assert!(chunked);
    assert_eq!(sched.stats().chunked_sequences, 1);
    assert_eq!(sched.queue_len(), 1);
}
#[test]
fn test_chunked_prefill_scheduler_next_chunk() {
    // The scheduler hands out the first window of the queued sequence.
    let mut sched = ChunkedPrefillScheduler::new(ChunkedPrefillConfig::default());
    sched.submit(1000);
    let (seq_id, range) = sched.next_chunk().expect("a chunk should be pending");
    assert_eq!(seq_id, 0);
    assert_eq!(range, 0..512);
}
#[test]
fn test_chunked_prefill_scheduler_complete_chunk() {
    // Completing a chunk updates latency stats and per-sequence progress.
    let mut sched = ChunkedPrefillScheduler::new(ChunkedPrefillConfig::default());
    sched.submit(1000);
    sched.complete_chunk(0, 512, 50);

    assert_eq!(sched.stats().chunks_processed, 1);
    assert_eq!(sched.stats().total_chunk_latency_ms, 50);
    assert_eq!(sched.stats().max_chunk_latency_ms, 50);

    let st = sched.get_state(0).expect("state should exist for seq 0");
    assert_eq!(st.processed_tokens, 512);
}
#[test]
fn test_chunked_prefill_scheduler_full_prefill() {
    // Drive a 1000-token prefill to completion in two chunks.
    let mut sched =
        ChunkedPrefillScheduler::new(ChunkedPrefillConfig::default().with_chunk_size(512));
    sched.submit(1000);

    let (seq_id, first) = sched.next_chunk().unwrap();
    assert_eq!(first, 0..512);
    sched.complete_chunk(seq_id, 512, 50);

    let (seq_id, second) = sched.next_chunk().unwrap();
    assert_eq!(second, 512..1000);
    sched.complete_chunk(seq_id, 488, 40);

    // Nothing left to prefill once both chunks are done.
    assert!(sched.next_chunk().is_none());
    assert!(!sched.has_pending_prefill(0));
}
#[test]
fn test_chunked_prefill_scheduler_has_pending_prefill() {
    // Pending status tracks unfinished chunked sequences only.
    let mut sched = ChunkedPrefillScheduler::new(ChunkedPrefillConfig::default());
    assert!(!sched.has_pending_prefill(999));
    sched.submit(1000);
    assert!(sched.has_pending_prefill(0));
    sched.complete_chunk(0, 1000, 100);
    assert!(!sched.has_pending_prefill(0));
}
#[test]
fn test_chunked_prefill_scheduler_remove() {
    // Removing a queued sequence shrinks the queue by one.
    let mut sched = ChunkedPrefillScheduler::new(ChunkedPrefillConfig::default());
    sched.submit(1000);
    sched.submit(2000);
    assert_eq!(sched.queue_len(), 2);
    assert!(sched.remove(0).is_some());
    assert_eq!(sched.queue_len(), 1);
}
#[test]
fn test_chunked_prefill_scheduler_clear() {
    // Clear drops all queued and pending prefill state.
    let mut sched = ChunkedPrefillScheduler::new(ChunkedPrefillConfig::default());
    sched.submit(1000);
    sched.submit(2000);
    sched.clear();
    assert_eq!(sched.queue_len(), 0);
    assert_eq!(sched.pending_count(), 0);
}
#[test]
fn test_chunked_prefill_scheduler_decode_interleave() {
    // Decode interleaving is suggested only while prefill work is queued.
    let mut sched = ChunkedPrefillScheduler::new(ChunkedPrefillConfig::default());
    assert!(!sched.should_interleave_decode());
    sched.submit(1000);
    assert!(sched.should_interleave_decode());
    sched.record_decode_interleave();
    assert_eq!(sched.stats().decode_interleaves, 1);
}
#[test]
fn test_chunked_prefill_scheduler_prefix_cache_hit() {
    // Prefix-cache hit tokens accumulate across calls.
    let mut sched = ChunkedPrefillScheduler::new(ChunkedPrefillConfig::default());
    sched.record_prefix_cache_hit(100);
    assert_eq!(sched.stats().prefix_cache_hits, 100);
    sched.record_prefix_cache_hit(50);
    assert_eq!(sched.stats().prefix_cache_hits, 150);
}
#[test]
fn test_chunked_prefill_scheduler_disabled() {
    // With chunking disabled, even huge prompts bypass.
    let mut sched = ChunkedPrefillScheduler::new(ChunkedPrefillConfig::disabled());
    let (_, chunked) = sched.submit(10000);
    assert!(!chunked);
    assert_eq!(sched.stats().bypassed_sequences, 1);
    assert_eq!(sched.stats().chunked_sequences, 0);
}
#[test]
fn test_chunked_prefill_scheduler_default() {
    // The Default impl uses the enabled default configuration.
    let sched = ChunkedPrefillScheduler::default();
    assert!(sched.config().enabled);
}
}