use std::collections::HashMap;
/// Numeric identifier of a [`Sequence`]. `Scheduler::add_request` hands these
/// out from a counter that starts at 1 and increases by one per request.
pub type SeqId = u64;
/// Lifecycle stage of a [`Sequence`] as tracked by the scheduler.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SeqState {
    /// Initial state given by `Sequence::new`; the sequence has not been
    /// admitted by `Scheduler::schedule` yet.
    Waiting,
    /// Admitted; prompt tokens are still being handed out in chunks.
    Prefilling,
    /// The whole prompt has been handed out; each step feeds one token.
    Decoding,
    /// Set by `Scheduler::finish_sequence`.
    Finished,
    /// NOTE(review): never assigned in this file; presumably reserved for
    /// preemption support. Confirm before relying on it.
    Preempted,
}
/// One generation request: its prompt, the tokens produced so far, and the
/// bookkeeping the scheduler needs to track progress and limits.
#[derive(Debug)]
pub struct Sequence {
    /// Identifier of this sequence.
    pub id: SeqId,
    /// Current lifecycle stage.
    pub state: SeqState,
    /// Tokens of the prompt.
    pub prompt_tokens: Vec<u32>,
    /// Tokens appended after the prompt (see `Scheduler::append_token`).
    pub output_tokens: Vec<u32>,
    /// Index into `prompt_tokens` of the next prompt token to prefill;
    /// everything before it has already been handed out in earlier batches.
    pub prompt_pos: usize,
    /// Cap on the combined prompt + output token count (see `at_limit`).
    pub max_tokens: usize,
    /// Set to `true` by `Scheduler::finish_sequence`.
    pub stopped: bool,
    /// Initialised to `id` by `Sequence::new`.
    /// NOTE(review): not read anywhere in this file; presumably an ordering
    /// key for a priority policy. Confirm against callers.
    pub priority: u64,
}
impl Sequence {
pub fn new(id: SeqId, prompt_tokens: Vec<u32>, max_tokens: usize) -> Self {
Self {
id,
state: SeqState::Waiting,
prompt_tokens,
output_tokens: Vec::new(),
prompt_pos: 0,
max_tokens,
stopped: false,
priority: id, }
}
pub fn total_tokens(&self) -> usize {
self.prompt_tokens.len() + self.output_tokens.len()
}
pub fn at_limit(&self) -> bool {
self.total_tokens() >= self.max_tokens
}
pub fn all_tokens(&self) -> Vec<u32> {
let mut tokens = self.prompt_tokens.clone();
tokens.extend_from_slice(&self.output_tokens);
tokens
}
}
/// Limits consulted by `Scheduler::schedule` when it assembles a batch.
#[derive(Clone, Debug)]
pub struct SchedulerConfig {
    /// Admission of waiting sequences stops once this many are active.
    pub max_sequences: usize,
    /// Token budget of a single scheduled batch.
    pub max_batch_tokens: usize,
    /// Cap on the number of entries in a single scheduled batch.
    pub max_batch_sequences: usize,
    /// Largest prompt chunk handed out for one sequence in one step.
    pub max_prefill_tokens: usize,
}
impl Default for SchedulerConfig {
fn default() -> Self {
Self {
max_sequences: 32,
max_batch_tokens: 512,
max_batch_sequences: 8,
max_prefill_tokens: 256,
}
}
}
/// Work selected for one step. The three vectors are parallel: index `i` of
/// each describes the same batch entry.
#[derive(Debug)]
pub struct ScheduledBatch {
    /// Sequence each entry belongs to.
    pub seq_ids: Vec<SeqId>,
    /// Tokens to feed for each entry: a prompt chunk for prefill entries, a
    /// single token for decode entries.
    pub tokens: Vec<Vec<u32>>,
    /// `true` for prefill entries, `false` for decode entries.
    pub is_prefill: Vec<bool>,
}
impl ScheduledBatch {
    /// Sum of the token counts of all entries.
    pub fn total_tokens(&self) -> usize {
        self.tokens.iter().map(Vec::len).sum()
    }

    /// Returns `true` when the batch has no entries, judged by `seq_ids`.
    pub fn is_empty(&self) -> bool {
        let ids = &self.seq_ids;
        ids.is_empty()
    }
}
/// Tracks every known sequence and decides which of them go into each batch.
pub struct Scheduler {
    /// Limits applied while building batches.
    config: SchedulerConfig,
    /// Every sequence known to the scheduler, keyed by id. Entries stay here
    /// until `remove_sequence` or `drain_finished` takes them out.
    sequences: HashMap<SeqId, Sequence>,
    /// Id that the next `add_request` call will hand out.
    next_id: SeqId,
    /// Ids not yet admitted, oldest first (new ids are pushed at the back
    /// and admission pops from the front).
    waiting_queue: Vec<SeqId>,
    /// Ids admitted by `schedule` that have not been retired since.
    active_ids: Vec<SeqId>,
}
impl Scheduler {
    /// Creates a scheduler with no sequences. Ids handed out by
    /// [`add_request`](Self::add_request) start at 1.
    pub fn new(config: SchedulerConfig) -> Self {
        Self {
            config,
            sequences: HashMap::new(),
            next_id: 1,
            waiting_queue: Vec::new(),
            active_ids: Vec::new(),
        }
    }

    /// Registers a new request, appends it to the waiting queue, and returns
    /// its id.
    pub fn add_request(&mut self, prompt_tokens: Vec<u32>, max_tokens: usize) -> SeqId {
        let id = self.next_id;
        self.next_id += 1;
        self.sequences
            .insert(id, Sequence::new(id, prompt_tokens, max_tokens));
        self.waiting_queue.push(id);
        id
    }

    /// Forgets a sequence entirely: drops it from the map, the waiting queue
    /// and the active list. Unknown ids are ignored.
    pub fn remove_sequence(&mut self, id: SeqId) {
        self.sequences.remove(&id);
        self.waiting_queue.retain(|&x| x != id);
        self.active_ids.retain(|&x| x != id);
    }

    /// Marks a sequence as finished and stops scheduling it. The sequence
    /// stays in the map until [`drain_finished`](Self::drain_finished) or
    /// [`remove_sequence`](Self::remove_sequence) takes it out.
    pub fn finish_sequence(&mut self, id: SeqId) {
        if let Some(seq) = self.sequences.get_mut(&id) {
            seq.state = SeqState::Finished;
            seq.stopped = true;
        }
        // A sequence may be finished before it was ever admitted. It has to
        // leave the waiting queue as well, otherwise the next `schedule` call
        // would admit it, overwrite its `Finished` state and prefill it.
        self.waiting_queue.retain(|&x| x != id);
        self.active_ids.retain(|&x| x != id);
    }

    /// Appends `token` to the output of sequence `id`. Unknown ids are
    /// ignored. No state or limit check is performed here.
    pub fn append_token(&mut self, id: SeqId, token: u32) {
        if let Some(seq) = self.sequences.get_mut(&id) {
            seq.output_tokens.push(token);
        }
    }

    /// Looks up a sequence by id.
    pub fn get_sequence(&self, id: SeqId) -> Option<&Sequence> {
        self.sequences.get(&id)
    }

    /// Number of sequences currently in the active list.
    pub fn active_count(&self) -> usize {
        self.active_ids.len()
    }

    /// Number of sequences still waiting for admission.
    pub fn waiting_count(&self) -> usize {
        self.waiting_queue.len()
    }

    /// Number of sequences held in the map, whatever their state.
    pub fn total_count(&self) -> usize {
        self.sequences.len()
    }

    /// Returns `true` while any sequence is waiting or active.
    pub fn has_work(&self) -> bool {
        !self.waiting_queue.is_empty() || !self.active_ids.is_empty()
    }

    /// Builds the batch for the next step. Every phase respects both
    /// `max_batch_tokens` and `max_batch_sequences`.
    ///
    /// 1. Admit waiting sequences in arrival order while fewer than
    ///    `max_sequences` are active, handing each its first prompt chunk.
    /// 2. Hand out the next prompt chunk to active sequences that are still
    ///    prefilling and are not in the batch yet.
    /// 3. Add one decode token (the last output token, else the last prompt
    ///    token, else `0`) for every decoding sequence that is neither
    ///    stopped nor at its token limit.
    /// 4. Retire stopped, finished and at-limit sequences from the active
    ///    list.
    ///
    /// A sequence with an empty prompt is admitted with an empty chunk and
    /// moves straight to `Decoding`.
    pub fn schedule(&mut self) -> ScheduledBatch {
        let token_cap = self.config.max_batch_tokens;
        let entry_cap = self.config.max_batch_sequences;
        let prefill_cap = self.config.max_prefill_tokens;

        let mut batch = ScheduledBatch {
            seq_ids: Vec::new(),
            tokens: Vec::new(),
            is_prefill: Vec::new(),
        };
        // Running token total, so the budget checks do not have to re-sum
        // the whole batch on every iteration.
        let mut used = 0usize;

        // Phase 1: admission.
        while !self.waiting_queue.is_empty()
            && self.active_ids.len() < self.config.max_sequences
            && batch.seq_ids.len() < entry_cap
            && used < token_cap
        {
            let id = self.waiting_queue.remove(0);
            let Some(seq) = self.sequences.get_mut(&id) else {
                continue;
            };
            // Never resurrect a sequence that has already been finished.
            if seq.stopped || seq.state == SeqState::Finished {
                continue;
            }
            seq.state = SeqState::Prefilling;
            self.active_ids.push(id);
            // `used < token_cap` holds here, so the subtraction cannot wrap.
            let chunk = Self::take_prefill_chunk(seq, prefill_cap.min(token_cap - used));
            used += chunk.len();
            Self::push_entry(&mut batch, id, chunk, true);
        }

        // Phase 2: continue prefilling sequences admitted in earlier steps.
        for &id in &self.active_ids {
            if used >= token_cap || batch.seq_ids.len() >= entry_cap {
                break;
            }
            if batch.seq_ids.contains(&id) {
                continue;
            }
            let Some(seq) = self.sequences.get_mut(&id) else {
                continue;
            };
            if seq.state != SeqState::Prefilling {
                continue;
            }
            let chunk = Self::take_prefill_chunk(seq, prefill_cap.min(token_cap - used));
            if chunk.is_empty() {
                continue;
            }
            used += chunk.len();
            Self::push_entry(&mut batch, id, chunk, true);
        }

        // Phase 3: one decode token per eligible sequence.
        for &id in &self.active_ids {
            if used >= token_cap || batch.seq_ids.len() >= entry_cap {
                break;
            }
            let Some(seq) = self.sequences.get(&id) else {
                continue;
            };
            let decodable = seq.state == SeqState::Decoding
                && !seq.stopped
                && !seq.at_limit()
                && !batch.seq_ids.contains(&id);
            if !decodable {
                continue;
            }
            let last_token = seq
                .output_tokens
                .last()
                .or(seq.prompt_tokens.last())
                .copied()
                .unwrap_or(0);
            used += 1;
            Self::push_entry(&mut batch, id, vec![last_token], false);
        }

        // Phase 4: retire sequences that can no longer make progress. The map
        // is borrowed up front so the closure does not capture `self` while
        // `active_ids` is mutably borrowed.
        let sequences = &self.sequences;
        self.active_ids.retain(|id| {
            sequences
                .get(id)
                .is_some_and(|s| !s.stopped && !s.at_limit() && s.state != SeqState::Finished)
        });

        batch
    }

    /// Removes and returns every sequence that is finished, stopped, or at
    /// its token limit, ordered by ascending id. The drained ids also leave
    /// the waiting queue and the active list.
    pub fn drain_finished(&mut self) -> Vec<Sequence> {
        let mut done_ids: Vec<SeqId> = self
            .sequences
            .iter()
            .filter(|(_, s)| s.state == SeqState::Finished || s.stopped || s.at_limit())
            .map(|(&id, _)| id)
            .collect();
        if done_ids.is_empty() {
            return Vec::new();
        }
        // Map iteration order is arbitrary; sort so callers get a stable order.
        done_ids.sort_unstable();

        let finished: Vec<Sequence> = done_ids
            .iter()
            .filter_map(|id| self.sequences.remove(id))
            .collect();

        // One pass per list: keep only the ids that are still in the map.
        let sequences = &self.sequences;
        self.waiting_queue.retain(|id| sequences.contains_key(id));
        self.active_ids.retain(|id| sequences.contains_key(id));

        finished
    }

    /// Copies up to `limit` not-yet-prefilled prompt tokens out of `seq`,
    /// advances its prefill position, and switches it to `Decoding` once the
    /// whole prompt has been handed out.
    fn take_prefill_chunk(seq: &mut Sequence, limit: usize) -> Vec<u32> {
        let prompt_len = seq.prompt_tokens.len();
        // Clamp so an out-of-range position cannot make the slice panic.
        let start = seq.prompt_pos.min(prompt_len);
        let end = start + limit.min(prompt_len - start);
        let chunk = seq.prompt_tokens[start..end].to_vec();
        seq.prompt_pos = end;
        if end >= prompt_len {
            seq.state = SeqState::Decoding;
        }
        chunk
    }

    /// Appends one entry to all three parallel vectors of `batch`.
    fn push_entry(batch: &mut ScheduledBatch, id: SeqId, tokens: Vec<u32>, is_prefill: bool) {
        batch.seq_ids.push(id);
        batch.tokens.push(tokens);
        batch.is_prefill.push(is_prefill);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scheduler built from `SchedulerConfig::default()`.
    fn default_scheduler() -> Scheduler {
        Scheduler::new(SchedulerConfig::default())
    }

    /// A lone request is admitted and fully prefilled by the first step.
    #[test]
    fn test_add_and_schedule_single() {
        let mut sched = default_scheduler();
        let seq = sched.add_request(vec![1, 2, 3], 10);
        assert_eq!(sched.waiting_count(), 1);
        assert_eq!(sched.active_count(), 0);

        let step = sched.schedule();
        assert_eq!(step.seq_ids.len(), 1);
        assert_eq!(step.seq_ids[0], seq);
        assert_eq!(step.tokens[0], vec![1, 2, 3]);
        assert!(step.is_prefill[0]);
        assert_eq!(sched.waiting_count(), 0);
        assert_eq!(sched.active_count(), 1);
    }

    /// After prefill, the next step feeds the most recently appended token.
    #[test]
    fn test_decode_after_prefill() {
        let mut sched = default_scheduler();
        let seq = sched.add_request(vec![1, 2, 3], 10);

        let prefill = sched.schedule();
        assert!(prefill.is_prefill[0]);

        sched.append_token(seq, 4);
        let decode = sched.schedule();
        assert_eq!(decode.seq_ids.len(), 1);
        assert!(!decode.is_prefill[0]);
        assert_eq!(decode.tokens[0], vec![4]);
    }

    /// A finished sequence is no longer scheduled or counted as active.
    #[test]
    fn test_finish_sequence() {
        let mut sched = default_scheduler();
        let seq = sched.add_request(vec![1], 5);
        sched.schedule();
        sched.finish_sequence(seq);

        let step = sched.schedule();
        assert!(step.is_empty());
        assert_eq!(sched.active_count(), 0);
    }

    /// Several short requests fit into one batch.
    #[test]
    fn test_multiple_sequences() {
        let mut sched = Scheduler::new(SchedulerConfig {
            max_batch_sequences: 4,
            ..SchedulerConfig::default()
        });
        let id1 = sched.add_request(vec![1, 2], 10);
        let id2 = sched.add_request(vec![3, 4], 10);
        let id3 = sched.add_request(vec![5, 6], 10);

        let step = sched.schedule();
        assert_eq!(step.seq_ids.len(), 3);
        for id in &[id1, id2, id3] {
            assert!(step.seq_ids.contains(id));
        }
    }

    /// Admission stops once `max_sequences` sequences are active.
    #[test]
    fn test_max_sequences_respected() {
        let mut sched = Scheduler::new(SchedulerConfig {
            max_sequences: 2,
            ..SchedulerConfig::default()
        });
        sched.add_request(vec![1], 10);
        sched.add_request(vec![2], 10);
        sched.add_request(vec![3], 10);

        let step = sched.schedule();
        assert_eq!(step.seq_ids.len(), 2);
        assert_eq!(sched.waiting_count(), 1);
    }

    /// Removing a waiting sequence clears it from the map and the queue.
    #[test]
    fn test_remove_sequence() {
        let mut sched = default_scheduler();
        let seq = sched.add_request(vec![1, 2, 3], 10);
        assert_eq!(sched.total_count(), 1);

        sched.remove_sequence(seq);
        assert_eq!(sched.total_count(), 0);
        assert_eq!(sched.waiting_count(), 0);
    }

    /// Once prompt + output reaches `max_tokens`, the sequence is skipped.
    #[test]
    fn test_at_limit_stops_scheduling() {
        let mut sched = default_scheduler();
        let seq = sched.add_request(vec![1], 3);

        sched.schedule();
        sched.append_token(seq, 2);
        sched.schedule();
        sched.append_token(seq, 3);

        let step = sched.schedule();
        let has_id = step.seq_ids.contains(&seq);
        assert!(!has_id, "at-limit sequence should not be scheduled");
    }

    /// Draining returns only the finished sequence and keeps the other one.
    #[test]
    fn test_drain_finished() {
        let mut sched = default_scheduler();
        let id1 = sched.add_request(vec![1], 10);
        let id2 = sched.add_request(vec![2], 10);
        sched.schedule();
        sched.finish_sequence(id1);

        let drained = sched.drain_finished();
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0].id, id1);
        assert_eq!(sched.total_count(), 1);
        assert!(sched.get_sequence(id2).is_some());
    }

    /// A 10-token prompt with a 4-token chunk size takes three steps: 4, 4, 2.
    #[test]
    fn test_long_prefill_chunked() {
        let mut sched = Scheduler::new(SchedulerConfig {
            max_prefill_tokens: 4,
            ..SchedulerConfig::default()
        });
        sched.add_request(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 20);

        let first = sched.schedule();
        assert_eq!(first.tokens[0], vec![1, 2, 3, 4]);
        assert!(first.is_prefill[0]);

        let second = sched.schedule();
        assert_eq!(second.tokens[0].len(), 4);

        let third = sched.schedule();
        assert_eq!(third.tokens[0].len(), 2);
    }

    /// `has_work` is false when empty and true once a request is queued.
    #[test]
    fn test_has_work() {
        let mut sched = default_scheduler();
        assert!(!sched.has_work());

        sched.add_request(vec![1], 5);
        assert!(sched.has_work());
    }
}