use std::collections::HashMap;
use std::time::Instant;
/// Prefill chunk size, in tokens.
///
/// Not referenced by any code in this file; the scheduler takes its per-call
/// chunk limit from `SchedulerConfig::max_prefill_tokens` instead.
/// NOTE(review): presumably consumed by callers elsewhere — confirm.
pub const PREFILL_CHUNK: usize = 512;
/// Threshold, in milliseconds, that `Sequence::decode_wait_exceeded` compares
/// against the time elapsed since the sequence's `last_emit_time`.
pub const MAX_DECODE_WAIT_MS: u64 = 100;
/// Identifier of a sequence; `Scheduler::add_request` hands these out from a
/// counter that starts at 1.
pub type SeqId = u64;
/// Lifecycle state of a [`Sequence`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeqState {
/// Initial state assigned by `Sequence::new`; the sequence has not been admitted yet.
Waiting,
/// Assigned by `Scheduler::schedule` on admission; kept while prompt tokens remain to be handed out.
Prefilling,
/// Assigned by `Scheduler::schedule` once `prompt_pos` has reached the end of the prompt.
Decoding,
/// Assigned by `Scheduler::finish_sequence`.
Finished,
/// Never assigned by the code in this file.
/// NOTE(review): presumably set by a preemption path elsewhere — confirm.
Preempted,
}
/// One generation request together with its scheduling bookkeeping.
#[derive(Debug)]
pub struct Sequence {
/// Identifier supplied at construction.
pub id: SeqId,
/// Current lifecycle state.
pub state: SeqState,
/// Prompt tokens supplied at construction.
pub prompt_tokens: Vec<u32>,
/// Generated tokens; `Scheduler::append_token` pushes onto this.
pub output_tokens: Vec<u32>,
/// Index of the next prompt token to hand out; advanced by `Scheduler::schedule`.
pub prompt_pos: usize,
/// Limit that `at_limit` compares against prompt length plus output length.
pub max_tokens: usize,
/// Set to `true` by `Scheduler::finish_sequence`.
pub stopped: bool,
/// Initialised to the sequence id; not read anywhere in this file.
pub priority: u64,
/// Initialised to `None`; not read or written elsewhere in this file.
/// NOTE(review): presumably a cache/slot index owned by another component — confirm.
pub slot_id: Option<usize>,
/// Progress counter moved by `advance_prefill`; tracked separately from `prompt_pos`.
pub prefill_progress: usize,
/// Prompt length captured at construction; denominator of `prefill_fraction`.
pub prefill_total: usize,
/// Set at construction and refreshed by `Scheduler::append_token`.
pub last_emit_time: Instant,
}
impl Sequence {
    /// Creates a sequence in the [`SeqState::Waiting`] state.
    ///
    /// `priority` starts out equal to `id`, `prefill_total` is the prompt
    /// length, and `last_emit_time` is the moment of construction.
    pub fn new(id: SeqId, prompt_tokens: Vec<u32>, max_tokens: usize) -> Self {
        let prefill_total = prompt_tokens.len();
        Self {
            id,
            state: SeqState::Waiting,
            prompt_tokens,
            output_tokens: Vec::new(),
            prompt_pos: 0,
            max_tokens,
            stopped: false,
            priority: id,
            slot_id: None,
            prefill_progress: 0,
            prefill_total,
            last_emit_time: Instant::now(),
        }
    }

    /// Returns `true` when the sequence is decoding, has not been stopped, and
    /// more than [`MAX_DECODE_WAIT_MS`] whole milliseconds have elapsed since
    /// `last_emit_time`.
    pub fn decode_wait_exceeded(&self) -> bool {
        // `as_millis` yields a u128; widen the threshold instead of narrowing
        // the measurement so no truncating cast is needed.
        self.state == SeqState::Decoding
            && !self.stopped
            && self.last_emit_time.elapsed().as_millis() > u128::from(MAX_DECODE_WAIT_MS)
    }

    /// Returns `prefill_progress / prefill_total`, or `1.0` when
    /// `prefill_total` is zero (nothing to prefill).
    pub fn prefill_fraction(&self) -> f32 {
        if self.prefill_total == 0 {
            1.0
        } else {
            self.prefill_progress as f32 / self.prefill_total as f32
        }
    }

    /// Advances `prefill_progress` by `n`, never past `prefill_total`.
    ///
    /// The addition saturates, so an oversized `n` cannot overflow, and the
    /// clamp keeps [`Sequence::prefill_fraction`] at or below `1.0`.
    /// `prompt_pos` is not touched.
    pub fn advance_prefill(&mut self, n: usize) {
        self.prefill_progress = self
            .prefill_progress
            .saturating_add(n)
            .min(self.prefill_total);
    }

    /// Number of prompt tokens plus number of output tokens.
    pub fn total_tokens(&self) -> usize {
        self.prompt_tokens.len() + self.output_tokens.len()
    }

    /// Returns `true` once [`Sequence::total_tokens`] has reached `max_tokens`.
    pub fn at_limit(&self) -> bool {
        self.total_tokens() >= self.max_tokens
    }

    /// Returns a fresh vector holding the prompt tokens followed by the output tokens.
    pub fn all_tokens(&self) -> Vec<u32> {
        // Reserve the final size up front so the vector is allocated once.
        let mut tokens = Vec::with_capacity(self.total_tokens());
        tokens.extend_from_slice(&self.prompt_tokens);
        tokens.extend_from_slice(&self.output_tokens);
        tokens
    }
}
/// Limits consulted by `Scheduler::schedule` when it assembles a batch.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
/// Cap on the number of active (admitted) sequences; checked before admitting a waiting sequence.
pub max_sequences: usize,
/// Token budget `Scheduler::schedule` applies when adding work to a batch.
pub max_batch_tokens: usize,
/// Cap on the number of sequences placed in one batch.
pub max_batch_sequences: usize,
/// Largest number of prompt tokens taken from one sequence per `schedule` call.
pub max_prefill_tokens: usize,
}
impl Default for SchedulerConfig {
fn default() -> Self {
Self {
max_sequences: 32,
max_batch_tokens: 512,
max_batch_sequences: 8,
max_prefill_tokens: 256,
}
}
}
/// Work selected by one `Scheduler::schedule` call.
///
/// The three vectors are filled in lockstep by `schedule`: entry `i` of each
/// describes the same sequence.
#[derive(Debug)]
pub struct ScheduledBatch {
/// Ids of the scheduled sequences.
pub seq_ids: Vec<SeqId>,
/// Tokens to process for each sequence: a prompt chunk for prefill entries, a single token for decode entries.
pub tokens: Vec<Vec<u32>>,
/// `true` for entries carrying a prompt chunk, `false` for decode entries.
pub is_prefill: Vec<bool>,
}
impl ScheduledBatch {
    /// Adds up the lengths of all per-sequence token lists in the batch.
    pub fn total_tokens(&self) -> usize {
        self.tokens
            .iter()
            .fold(0, |running, entry| running + entry.len())
    }

    /// Reports whether the batch holds no sequences at all.
    pub fn is_empty(&self) -> bool {
        let ids = &self.seq_ids;
        ids.is_empty()
    }
}
/// Batching scheduler that tracks sequences from admission until they are
/// removed or drained.
pub struct Scheduler {
/// Limits applied by `schedule`.
config: SchedulerConfig,
/// Every known sequence keyed by id, until `remove_sequence` or `drain_finished` takes it out.
sequences: HashMap<SeqId, Sequence>,
/// Id given to the next `add_request` call; starts at 1.
next_id: SeqId,
/// Ids not yet admitted, in arrival order; `schedule` takes from the front.
waiting_queue: Vec<SeqId>,
/// Ids that have been admitted by `schedule` and not yet retired.
active_ids: Vec<SeqId>,
}
impl Scheduler {
    /// Creates an empty scheduler; the first id it hands out is 1.
    pub fn new(config: SchedulerConfig) -> Self {
        Self {
            config,
            sequences: HashMap::new(),
            next_id: 1,
            waiting_queue: Vec::new(),
            active_ids: Vec::new(),
        }
    }

    /// Registers a new request at the back of the waiting queue and returns its id.
    pub fn add_request(&mut self, prompt_tokens: Vec<u32>, max_tokens: usize) -> SeqId {
        let id = self.next_id;
        self.next_id += 1;
        self.sequences
            .insert(id, Sequence::new(id, prompt_tokens, max_tokens));
        self.waiting_queue.push(id);
        id
    }

    /// Forgets a sequence entirely: drops it from storage and from both id lists.
    pub fn remove_sequence(&mut self, id: SeqId) {
        self.sequences.remove(&id);
        self.waiting_queue.retain(|&x| x != id);
        self.active_ids.retain(|&x| x != id);
    }

    /// Marks a sequence finished and stopped and retires it from the active list.
    ///
    /// The sequence stays in storage until [`Scheduler::drain_finished`] or
    /// [`Scheduler::remove_sequence`] takes it out.
    pub fn finish_sequence(&mut self, id: SeqId) {
        if let Some(seq) = self.sequences.get_mut(&id) {
            seq.state = SeqState::Finished;
            seq.stopped = true;
        }
        self.active_ids.retain(|&x| x != id);
    }

    /// Appends a generated token to the sequence and refreshes its
    /// `last_emit_time`. Unknown ids are ignored.
    pub fn append_token(&mut self, id: SeqId, token: u32) {
        if let Some(seq) = self.sequences.get_mut(&id) {
            seq.output_tokens.push(token);
            seq.last_emit_time = Instant::now();
        }
    }

    /// Looks up a sequence by id.
    pub fn get_sequence(&self, id: SeqId) -> Option<&Sequence> {
        self.sequences.get(&id)
    }

    /// Number of admitted sequences that have not been retired.
    pub fn active_count(&self) -> usize {
        self.active_ids.len()
    }

    /// Number of sequences still waiting for admission.
    pub fn waiting_count(&self) -> usize {
        self.waiting_queue.len()
    }

    /// Number of sequences held in storage, in any state.
    pub fn total_count(&self) -> usize {
        self.sequences.len()
    }

    /// Returns `true` while any sequence is waiting or active.
    pub fn has_work(&self) -> bool {
        !self.waiting_queue.is_empty() || !self.active_ids.is_empty()
    }

    /// Takes up to `limit` not-yet-scheduled prompt tokens from `seq`, moves
    /// its `prompt_pos` past them, and switches it to `Decoding` once the
    /// whole prompt has been handed out.
    fn take_prefill_chunk(seq: &mut Sequence, limit: usize) -> Vec<u32> {
        let len = seq.prompt_tokens.len();
        // Clamp the cursor so the subtraction below cannot underflow.
        let start = seq.prompt_pos.min(len);
        let end = start + limit.min(len - start);
        let chunk = seq.prompt_tokens[start..end].to_vec();
        seq.prompt_pos = end;
        if seq.prompt_pos >= len {
            seq.state = SeqState::Decoding;
        }
        chunk
    }

    /// Appends one entry to all three parallel vectors of `batch`.
    fn push_entry(batch: &mut ScheduledBatch, id: SeqId, tokens: Vec<u32>, is_prefill: bool) {
        batch.seq_ids.push(id);
        batch.tokens.push(tokens);
        batch.is_prefill.push(is_prefill);
    }

    /// Builds the next batch in three phases, all sharing the
    /// `max_batch_tokens` and `max_batch_sequences` limits:
    ///
    /// 1. Admission — waiting sequences are moved to the active list (while
    ///    fewer than `max_sequences` are active) and given their first prompt
    ///    chunk.
    /// 2. Continuing prefill — active sequences still in `Prefilling` receive
    ///    their next prompt chunk.
    /// 3. Decode — active sequences in `Decoding` that are neither stopped nor
    ///    at their token limit contribute their most recent token (last
    ///    output token, else last prompt token, else `0`).
    ///
    /// A prompt chunk is at most `max_prefill_tokens` long and never larger
    /// than the token budget left in the batch. A sequence appears at most
    /// once per batch. Afterwards, sequences that are stopped, finished, or
    /// at their token limit are retired from the active list.
    ///
    /// A waiting sequence with an empty prompt is admitted with an empty
    /// prefill entry and moves straight to `Decoding`.
    pub fn schedule(&mut self) -> ScheduledBatch {
        let token_cap = self.config.max_batch_tokens;
        let seq_cap = self.config.max_batch_sequences;
        let chunk_cap = self.config.max_prefill_tokens;
        let mut batch = ScheduledBatch {
            seq_ids: Vec::new(),
            tokens: Vec::new(),
            is_prefill: Vec::new(),
        };
        // Running count of tokens already placed in `batch`; avoids
        // re-summing the batch on every iteration.
        let mut used = 0usize;

        // Phase 1: admit waiting sequences. The token budget is checked here
        // too, so admission alone cannot overshoot `max_batch_tokens`.
        while used < token_cap
            && !self.waiting_queue.is_empty()
            && self.active_ids.len() < self.config.max_sequences
            && batch.seq_ids.len() < seq_cap
        {
            let id = self.waiting_queue.remove(0);
            // A queued id without a stored sequence is simply dropped.
            let Some(seq) = self.sequences.get_mut(&id) else {
                continue;
            };
            seq.state = SeqState::Prefilling;
            self.active_ids.push(id);
            let chunk = Self::take_prefill_chunk(seq, chunk_cap.min(token_cap - used));
            used += chunk.len();
            Self::push_entry(&mut batch, id, chunk, true);
        }

        // Phase 2: continue prefill for sequences admitted by earlier calls.
        for &id in &self.active_ids {
            if used >= token_cap || batch.seq_ids.len() >= seq_cap {
                break;
            }
            if batch.seq_ids.contains(&id) {
                continue;
            }
            let Some(seq) = self.sequences.get_mut(&id) else {
                continue;
            };
            if seq.state != SeqState::Prefilling {
                continue;
            }
            let chunk = Self::take_prefill_chunk(seq, chunk_cap.min(token_cap - used));
            if chunk.is_empty() {
                continue;
            }
            used += chunk.len();
            Self::push_entry(&mut batch, id, chunk, true);
        }

        // Phase 3: one decode step for every eligible decoding sequence.
        for &id in &self.active_ids {
            if batch.seq_ids.len() >= seq_cap || used >= token_cap {
                break;
            }
            let Some(seq) = self.sequences.get(&id) else {
                continue;
            };
            let eligible = seq.state == SeqState::Decoding
                && !seq.stopped
                && !seq.at_limit()
                && !batch.seq_ids.contains(&id);
            if !eligible {
                continue;
            }
            let last_token = seq
                .output_tokens
                .last()
                .or(seq.prompt_tokens.last())
                .copied()
                .unwrap_or(0);
            used += 1;
            Self::push_entry(&mut batch, id, vec![last_token], false);
        }

        // Retire sequences that can no longer make progress. Borrow the map
        // separately so the closure does not capture all of `self`.
        let sequences = &self.sequences;
        self.active_ids.retain(|id| {
            sequences
                .get(id)
                .is_some_and(|s| !s.stopped && !s.at_limit() && s.state != SeqState::Finished)
        });
        batch
    }

    /// Removes and returns every stored sequence that is finished, stopped,
    /// or at its token limit — including ones still waiting for admission.
    ///
    /// The order of the returned sequences is unspecified. Both id lists are
    /// then pruned of ids that no longer have a stored sequence.
    pub fn drain_finished(&mut self) -> Vec<Sequence> {
        let finished_ids: Vec<SeqId> = self
            .sequences
            .iter()
            .filter(|(_, s)| s.state == SeqState::Finished || s.stopped || s.at_limit())
            .map(|(&id, _)| id)
            .collect();
        let mut finished = Vec::with_capacity(finished_ids.len());
        for id in &finished_ids {
            if let Some(seq) = self.sequences.remove(id) {
                finished.push(seq);
            }
        }
        // One pass per list instead of one pass per drained id.
        let sequences = &self.sequences;
        self.active_ids.retain(|id| sequences.contains_key(id));
        self.waiting_queue.retain(|id| sequences.contains_key(id));
        finished
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the scheduler and its sequence bookkeeping.
    use super::*;

    /// A single request is prefilled in full on the first pass and becomes active.
    #[test]
    fn test_add_and_schedule_single() {
        let mut scheduler = Scheduler::new(SchedulerConfig::default());
        let id = scheduler.add_request(vec![1, 2, 3], 10);
        assert_eq!(scheduler.waiting_count(), 1);
        assert_eq!(scheduler.active_count(), 0);
        let batch = scheduler.schedule();
        assert_eq!(batch.seq_ids.len(), 1);
        assert_eq!(batch.seq_ids[0], id);
        assert_eq!(batch.tokens[0], vec![1, 2, 3]);
        assert!(batch.is_prefill[0]);
        assert_eq!(scheduler.waiting_count(), 0);
        assert_eq!(scheduler.active_count(), 1);
    }

    /// Once the prompt is consumed, the next pass decodes from the last appended token.
    #[test]
    fn test_decode_after_prefill() {
        let mut scheduler = Scheduler::new(SchedulerConfig::default());
        let id = scheduler.add_request(vec![1, 2, 3], 10);
        let batch = scheduler.schedule();
        assert!(batch.is_prefill[0]);
        scheduler.append_token(id, 4);
        let batch = scheduler.schedule();
        assert_eq!(batch.seq_ids.len(), 1);
        assert!(!batch.is_prefill[0]);
        assert_eq!(batch.tokens[0], vec![4]);
    }

    /// A finished sequence is no longer scheduled or counted as active.
    #[test]
    fn test_finish_sequence() {
        let mut scheduler = Scheduler::new(SchedulerConfig::default());
        let id = scheduler.add_request(vec![1], 5);
        scheduler.schedule();
        scheduler.finish_sequence(id);
        let batch = scheduler.schedule();
        assert!(batch.is_empty());
        assert_eq!(scheduler.active_count(), 0);
    }

    /// Several waiting requests are admitted into the same batch.
    #[test]
    fn test_multiple_sequences() {
        let config = SchedulerConfig {
            max_batch_sequences: 4,
            ..SchedulerConfig::default()
        };
        let mut scheduler = Scheduler::new(config);
        let id1 = scheduler.add_request(vec![1, 2], 10);
        let id2 = scheduler.add_request(vec![3, 4], 10);
        let id3 = scheduler.add_request(vec![5, 6], 10);
        let batch = scheduler.schedule();
        assert_eq!(batch.seq_ids.len(), 3);
        assert!(batch.seq_ids.contains(&id1));
        assert!(batch.seq_ids.contains(&id2));
        assert!(batch.seq_ids.contains(&id3));
    }

    /// Admission stops once `max_sequences` sequences are active.
    #[test]
    fn test_max_sequences_respected() {
        let config = SchedulerConfig {
            max_sequences: 2,
            ..SchedulerConfig::default()
        };
        let mut scheduler = Scheduler::new(config);
        scheduler.add_request(vec![1], 10);
        scheduler.add_request(vec![2], 10);
        scheduler.add_request(vec![3], 10);
        let batch = scheduler.schedule();
        assert_eq!(batch.seq_ids.len(), 2);
        assert_eq!(scheduler.waiting_count(), 1);
    }

    /// Removing a waiting sequence clears it from storage and from the queue.
    #[test]
    fn test_remove_sequence() {
        let mut scheduler = Scheduler::new(SchedulerConfig::default());
        let id = scheduler.add_request(vec![1, 2, 3], 10);
        assert_eq!(scheduler.total_count(), 1);
        scheduler.remove_sequence(id);
        assert_eq!(scheduler.total_count(), 0);
        assert_eq!(scheduler.waiting_count(), 0);
    }

    /// A sequence whose prompt plus output reaches `max_tokens` is not scheduled again.
    #[test]
    fn test_at_limit_stops_scheduling() {
        let mut scheduler = Scheduler::new(SchedulerConfig::default());
        let id = scheduler.add_request(vec![1], 3);
        scheduler.schedule();
        scheduler.append_token(id, 2);
        scheduler.schedule();
        scheduler.append_token(id, 3);
        let batch = scheduler.schedule();
        let has_id = batch.seq_ids.contains(&id);
        assert!(!has_id, "at-limit sequence should not be scheduled");
    }

    /// Draining returns only the finished sequence and leaves the other in place.
    #[test]
    fn test_drain_finished() {
        let mut scheduler = Scheduler::new(SchedulerConfig::default());
        let id1 = scheduler.add_request(vec![1], 10);
        let id2 = scheduler.add_request(vec![2], 10);
        scheduler.schedule();
        scheduler.finish_sequence(id1);
        let finished = scheduler.drain_finished();
        assert_eq!(finished.len(), 1);
        assert_eq!(finished[0].id, id1);
        assert_eq!(scheduler.total_count(), 1);
        assert!(scheduler.get_sequence(id2).is_some());
    }

    /// A 10-token prompt with a 4-token chunk limit is handed out as 4 + 4 + 2.
    #[test]
    fn test_long_prefill_chunked() {
        let config = SchedulerConfig {
            max_prefill_tokens: 4,
            ..SchedulerConfig::default()
        };
        let mut scheduler = Scheduler::new(config);
        scheduler.add_request(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 20);
        let batch = scheduler.schedule();
        assert_eq!(batch.tokens[0], vec![1, 2, 3, 4]);
        assert!(batch.is_prefill[0]);
        let batch = scheduler.schedule();
        assert_eq!(batch.tokens[0].len(), 4);
        let batch = scheduler.schedule();
        assert_eq!(batch.tokens[0].len(), 2);
    }

    /// `has_work` flips to true as soon as a request is queued.
    #[test]
    fn test_has_work() {
        let mut scheduler = Scheduler::new(SchedulerConfig::default());
        assert!(!scheduler.has_work());
        scheduler.add_request(vec![1], 5);
        assert!(scheduler.has_work());
    }

    /// `advance_prefill` moves `prefill_progress` and `prefill_fraction` tracks it.
    #[test]
    fn chunked_prefill_reports_progress() {
        let prompt: Vec<u32> = (1..=10).collect();
        let mut scheduler = Scheduler::new(SchedulerConfig::default());
        let id = scheduler.add_request(prompt.clone(), 20);
        {
            let seq = scheduler.get_sequence(id).expect("sequence must exist");
            assert_eq!(seq.prefill_progress, 0, "progress starts at 0");
            assert_eq!(seq.prefill_total, 10, "total equals prompt length");
            assert_eq!(seq.prefill_fraction(), 0.0, "fraction starts at 0.0");
        }
        if let Some(seq) = scheduler.sequences.get_mut(&id) {
            seq.advance_prefill(5);
        }
        {
            let seq = scheduler.get_sequence(id).expect("sequence must exist");
            assert_eq!(seq.prefill_progress, 5);
            assert!(
                (seq.prefill_fraction() - 0.5).abs() < 1e-6,
                "half progress → fraction 0.5"
            );
        }
        if let Some(seq) = scheduler.sequences.get_mut(&id) {
            seq.advance_prefill(5);
        }
        {
            let seq = scheduler.get_sequence(id).expect("sequence must exist");
            assert_eq!(seq.prefill_progress, 10);
            assert!(
                (seq.prefill_fraction() - 1.0).abs() < 1e-6,
                "full progress → fraction 1.0"
            );
        }
    }

    /// The prefill chunks emitted across passes concatenate to exactly the prompt.
    #[test]
    fn chunked_prefill_kv_matches_singleshot() {
        let prompt: Vec<u32> = (1..=8).collect();
        let chunk = 4usize;
        let config = SchedulerConfig {
            max_prefill_tokens: chunk,
            ..SchedulerConfig::default()
        };
        let mut sched = Scheduler::new(config);
        sched.add_request(prompt.clone(), 20);
        let mut all_prefill_tokens: Vec<u32> = Vec::new();
        for _ in 0..4 {
            let batch = sched.schedule();
            if batch.is_empty() {
                break;
            }
            for (i, &is_pf) in batch.is_prefill.iter().enumerate() {
                if is_pf {
                    all_prefill_tokens.extend_from_slice(&batch.tokens[i]);
                }
            }
        }
        assert_eq!(
            all_prefill_tokens, prompt,
            "chunked prefill must tile the full prompt without gaps or overlaps"
        );
    }

    /// A sequence that has only just started decoding has not exceeded the wait threshold.
    #[test]
    fn decode_wait_exceeded_false_initially() {
        let mut sched = Scheduler::new(SchedulerConfig::default());
        let id = sched.add_request(vec![1, 2], 10);
        sched.schedule();
        if let Some(seq) = sched.sequences.get_mut(&id) {
            seq.state = SeqState::Decoding;
        }
        let seq = sched.get_sequence(id).expect("must exist");
        assert!(
            !seq.decode_wait_exceeded(),
            "newly-created sequence must not exceed decode wait immediately"
        );
    }

    /// `advance_prefill` changes `prefill_progress` only, never `prompt_pos`.
    #[test]
    fn advance_prefill_is_independent_of_prompt_pos() {
        let prompt: Vec<u32> = vec![1, 2, 3, 4];
        let mut sched = Scheduler::new(SchedulerConfig::default());
        let id = sched.add_request(prompt, 10);
        let initial_prompt_pos = sched.get_sequence(id).expect("must exist").prompt_pos;
        if let Some(seq) = sched.sequences.get_mut(&id) {
            seq.advance_prefill(2);
        }
        let seq = sched.get_sequence(id).expect("must exist");
        assert_eq!(
            seq.prompt_pos, initial_prompt_pos,
            "prompt_pos must be unchanged by advance_prefill"
        );
        assert_eq!(
            seq.prefill_progress, 2,
            "prefill_progress must advance by 2"
        );
    }

    /// `append_token` must move `last_emit_time` forward.
    #[test]
    fn append_token_refreshes_last_emit_time() {
        let mut sched = Scheduler::new(SchedulerConfig::default());
        let id = sched.add_request(vec![1], 10);
        let t_before = sched.get_sequence(id).expect("must exist").last_emit_time;
        std::thread::sleep(std::time::Duration::from_millis(2));
        sched.append_token(id, 99);
        let t_after = sched.get_sequence(id).expect("must exist").last_emit_time;
        // Strict comparison: with `>=` this test would also pass if
        // `append_token` left the timestamp untouched.
        assert!(
            t_after > t_before,
            "last_emit_time must move forward after append_token"
        );
    }

    /// An empty prompt counts as fully prefilled.
    #[test]
    fn prefill_fraction_one_for_empty_prompt() {
        let mut sched = Scheduler::new(SchedulerConfig::default());
        let id = sched.add_request(vec![], 10);
        let seq = sched.get_sequence(id).expect("must exist");
        assert!(
            (seq.prefill_fraction() - 1.0).abs() < 1e-6,
            "empty prompt prefill_fraction must be 1.0"
        );
    }

    /// Pins the values of the public tuning constants.
    #[test]
    fn prefill_fairness_constants() {
        assert_eq!(PREFILL_CHUNK, 512);
        assert_eq!(MAX_DECODE_WAIT_MS, 100);
    }
}