#![allow(clippy::must_use_candidate)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::unnecessary_wraps)] #![allow(clippy::derivable_impls)] #![allow(clippy::option_if_let_else)]
use crate::paged_kv::{PagedCacheError, PagedKvCache, SeqId};
use serde::{Deserialize, Serialize};
use std::collections::{BinaryHeap, HashMap, VecDeque};
use std::time::Instant;
use thiserror::Error;
mod chunked_prefill;
mod types;
pub use chunked_prefill::{
ChunkedPrefillConfig, ChunkedPrefillScheduler, ChunkedPrefillState, ChunkedPrefillStats,
};
pub use types::{Priority, SchedulerStats, SequenceState};
/// Errors returned by the scheduler's public operations.
#[derive(Debug, Error)]
pub enum SchedulerError {
/// The waiting queue already holds `capacity` requests; admission refused.
#[error("Request queue full: capacity {capacity}")]
QueueFull {
capacity: usize,
},
/// No request with this id is currently tracked by the scheduler.
#[error("Request not found: {0}")]
RequestNotFound(u64),
/// Failure propagated from the paged KV cache (auto-converted via `#[from]`).
#[error("KV cache error: {0}")]
CacheError(#[from] PagedCacheError),
/// An internal invariant was violated; the message describes the state.
#[error("Invalid scheduler state: {0}")]
InvalidState(String),
}
/// One generation request, tracked by the scheduler across its whole lifetime
/// (waiting -> running -> complete, possibly preempted in between).
#[derive(Debug, Clone)]
pub struct SchedulerRequest {
/// Unique identifier for this request.
pub request_id: u64,
/// Prompt token ids to prefill.
pub input_ids: Vec<u32>,
/// Hard cap on the number of tokens to generate.
pub max_tokens: usize,
/// Scheduling priority; higher-priority requests are dequeued first.
pub priority: Priority,
/// When the request was submitted; basis for `wait_time()`.
pub arrival_time: Instant,
/// KV-cache sequence id, assigned once the request is admitted to the cache.
pub seq_id: Option<SeqId>,
/// Current lifecycle state (starts as `Waiting`).
pub state: SequenceState,
/// Tokens produced so far by decoding.
pub generated_tokens: Vec<u32>,
/// Number of scheduler iterations this request has participated in.
pub iterations: usize,
}
impl SchedulerRequest {
    /// Builds a fresh request in the `Waiting` state with default priority,
    /// stamping `arrival_time` at construction for wait-time accounting.
    pub fn new(request_id: u64, input_ids: Vec<u32>, max_tokens: usize) -> Self {
        Self {
            state: SequenceState::Waiting,
            priority: Priority::default(),
            arrival_time: Instant::now(),
            seq_id: None,
            generated_tokens: Vec::new(),
            iterations: 0,
            request_id,
            input_ids,
            max_tokens,
        }
    }

    /// Builder-style override of the default priority.
    pub fn with_priority(self, priority: Priority) -> Self {
        Self { priority, ..self }
    }

    /// Prompt length plus tokens generated so far.
    pub fn total_tokens(&self) -> usize {
        self.generated_tokens.len() + self.input_ids.len()
    }

    /// Tokens still allowed under `max_tokens`; never underflows.
    pub fn remaining_tokens(&self) -> usize {
        let produced = self.generated_tokens.len();
        if produced >= self.max_tokens {
            0
        } else {
            self.max_tokens - produced
        }
    }

    /// A request is finished once it has emitted `eos_token` as its most
    /// recent token, or has produced `max_tokens` tokens.
    pub fn is_complete(&self, eos_token: u32) -> bool {
        let hit_eos = self.generated_tokens.last().map_or(false, |&t| t == eos_token);
        hit_eos || self.generated_tokens.len() >= self.max_tokens
    }

    /// Time elapsed since the request arrived.
    pub fn wait_time(&self) -> std::time::Duration {
        Instant::now().duration_since(self.arrival_time)
    }
}
/// Heap entry for the scheduler's waiting queue; the sort key lives in the
/// `Ord` impl, while the payload is just the owning `request_id`.
#[derive(Debug)]
struct PriorityEntry {
// Scheduling priority of the referenced request.
priority: Priority,
// Submission time of the referenced request; used for FIFO tie-breaking.
arrival_time: Instant,
// Id of the request this entry refers to.
request_id: u64,
}
impl PartialEq for PriorityEntry {
// Two entries are equal iff they refer to the same request; request ids
// are assumed unique per scheduler (see `next_request_id`).
fn eq(&self, other: &Self) -> bool {
self.request_id == other.request_id
}
}
impl Eq for PriorityEntry {}
impl PartialOrd for PriorityEntry {
// Canonical forwarding impl: delegate to `Ord` so the partial and total
// orders can never disagree.
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for PriorityEntry {
    /// Max-heap ordering for `BinaryHeap`: entries with a greater `Priority`
    /// pop first; within a priority level the comparison of `arrival_time` is
    /// reversed so the *earlier* arrival sorts as "greater" and pops first
    /// (FIFO). The unique `request_id` is the final tiebreaker (also reversed,
    /// so the earlier-assigned id pops first), which makes the order total and
    /// consistent with `PartialEq`: `cmp` returns `Equal` only for entries of
    /// the same request. Without the tiebreaker, two requests with equal
    /// priority and identical arrival instants (possible on coarse clocks)
    /// would compare `Equal` yet not be `eq`, violating the `Ord` contract
    /// and leaving their pop order nondeterministic.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.priority
            .cmp(&other.priority)
            .then_with(|| other.arrival_time.cmp(&self.arrival_time))
            .then_with(|| other.request_id.cmp(&self.request_id))
    }
}
/// Result of one scheduling step: which sequences run, which were preempted
/// or completed, and how many tokens of each phase were scheduled.
#[derive(Debug, Clone, Default)]
pub struct SchedulerOutput {
/// KV-cache sequence ids scheduled to run this step.
pub scheduled_seq_ids: Vec<SeqId>,
/// Request ids scheduled to run this step (parallel to `scheduled_seq_ids`).
pub scheduled_request_ids: Vec<u64>,
/// Sequence ids preempted (evicted from running) during this step.
pub preempted_seq_ids: Vec<SeqId>,
/// Request ids that finished during this step.
pub completed_request_ids: Vec<u64>,
/// Prompt tokens scheduled for prefill this step.
pub num_prefill_tokens: usize,
/// Tokens scheduled for decode this step.
pub num_decode_tokens: usize,
}
impl SchedulerOutput {
    /// Combined token count for this step (prefill + decode).
    pub fn total_tokens(&self) -> usize {
        self.num_decode_tokens + self.num_prefill_tokens
    }

    /// Whether this step scheduled no sequences at all.
    pub fn is_empty(&self) -> bool {
        self.scheduled_seq_ids.first().is_none()
    }
}
/// Continuous-batching request scheduler: admits requests into a bounded
/// priority queue and tracks running / preempted sets. Method impls are
/// provided by the `include!`d files at the bottom of this module.
pub struct Scheduler {
// Every tracked request (waiting, running, or preempted), keyed by id.
requests: HashMap<u64, SchedulerRequest>,
// Max-heap of waiting requests; see `Ord for PriorityEntry` for the order.
waiting_queue: BinaryHeap<PriorityEntry>,
// Ids of requests currently in the running batch.
running: Vec<u64>,
// Ids of preempted requests; a `VecDeque`, presumably resumed FIFO —
// NOTE(review): confirm against the include!d scheduling logic.
preempted: VecDeque<u64>,
// Admission/batching limits — enforcement lives in the include!d impls.
max_batch_size: usize,
max_queue_size: usize,
max_tokens_per_batch: usize,
// Monotonic counter for assigning request ids.
next_request_id: u64,
// Aggregate counters exposed to callers.
stats: SchedulerStats,
// Running sum of wait times in milliseconds — presumably feeds an average
// in `stats`; TODO confirm in the include!d impls.
total_wait_time_ms: f64,
}
// NOTE(review): these files are spliced in textually via `include!`,
// presumably each contributing `impl Scheduler` blocks — consider proper
// `mod` submodules instead, since `include!` hides the code from module
// tooling and makes file boundaries invisible to rustfmt/clippy.
include!("mod_max_scheduler.rs");
include!("mod_num_slots_idle.rs");
include!("mod_max_default_batch.rs");
include!("mod_dynamic_scheduler.rs");