impl Default for BatchConfig {
fn default() -> Self {
Self {
max_ubatch_tokens: 512,
max_sbatch_sequences: 8,
prefer_pure_decode: true,
max_context_length: 2048,
dynamic_batching: true,
}
}
}
impl BatchConfig {
    /// Builder-style setter: caps the number of tokens per micro-batch.
    pub fn with_max_tokens(self, max_tokens: usize) -> Self {
        Self {
            max_ubatch_tokens: max_tokens,
            ..self
        }
    }

    /// Builder-style setter: caps the number of sequences per sequence batch.
    pub fn with_max_sequences(self, max_seqs: usize) -> Self {
        Self {
            max_sbatch_sequences: max_seqs,
            ..self
        }
    }
}
/// Two-level batch scheduler: keeps a sequence batch (sbatch) of active
/// sequences and carves token micro-batches (ubatches) out of it.
pub struct BatchScheduler {
    // Limits and policy knobs used when building batches.
    config: BatchConfig,
    // The currently scheduled sequences.
    sbatch: SequenceBatch,
    // Monotonic counter for assigning the next sequence index.
    next_seq_idx: usize,
    // Running counters/averages for batches created so far.
    stats: BatchStats,
}
/// Accumulated batching statistics, serializable for metrics export.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BatchStats {
    /// Total micro-batches created.
    pub ubatches_created: u64,
    /// Total sequence batches created.
    /// NOTE(review): not updated anywhere in this chunk — presumably
    /// maintained by code outside this view; verify.
    pub sbatches_created: u64,
    /// Total tokens placed into micro-batches (prefill + decode).
    pub tokens_processed: u64,
    /// Tokens scheduled as prefill (prompt) work.
    pub prefill_tokens: u64,
    /// Tokens scheduled as decode (generation) work.
    pub decode_tokens: u64,
    /// Running mean of micro-batch sizes, in tokens.
    pub avg_ubatch_size: f64,
    /// Running mean of sequence-batch sizes.
    /// NOTE(review): not updated anywhere in this chunk — verify.
    pub avg_sbatch_size: f64,
}
impl BatchScheduler {
    /// Creates a scheduler with the default [`BatchConfig`].
    pub fn new() -> Self {
        Self::with_config(BatchConfig::default())
    }

    /// Creates a scheduler with an explicit configuration.
    pub fn with_config(config: BatchConfig) -> Self {
        let max_seqs = config.max_sbatch_sequences;
        Self {
            config,
            sbatch: SequenceBatch::new(max_seqs),
            next_seq_idx: 0,
            stats: BatchStats::default(),
        }
    }

    /// Admits a new sequence (starting in prefill) into the sequence batch.
    ///
    /// Returns the sequence index assigned to it, or `None` when the batch
    /// is full or the underlying insert is rejected.
    pub fn add_sequence(
        &mut self,
        slot_id: usize,
        request_id: u64,
        input_tokens: Vec<u32>,
    ) -> Option<usize> {
        if self.sbatch.is_full() {
            return None;
        }
        let seq_idx = self.next_seq_idx;
        let entry = SequenceBatchEntry::new(seq_idx, slot_id, request_id).with_tokens(input_tokens);
        if self.sbatch.add_sequence(entry) {
            // Advance the counter only after a successful insert. The old
            // code incremented it up front, leaking a sequence index every
            // time `SequenceBatch::add_sequence` rejected the entry.
            self.next_seq_idx += 1;
            Some(seq_idx)
        } else {
            None
        }
    }

    /// Removes a finished sequence and returns its batch entry, if present.
    pub fn complete_sequence(&mut self, seq_idx: usize) -> Option<SequenceBatchEntry> {
        self.sbatch.remove_sequence(seq_idx)
    }

    /// Transitions a sequence from prefill to decode at `position`,
    /// dropping its buffered prompt tokens (no longer needed once prefill
    /// is done). Returns `false` when the sequence index is unknown.
    pub fn start_decode(&mut self, seq_idx: usize, position: usize) -> bool {
        if let Some(entry) = self.sbatch.get_mut(seq_idx) {
            entry.is_prefill = false;
            entry.position = position;
            entry.tokens.clear();
            true
        } else {
            false
        }
    }

    /// Builds the next micro-batch from the current sequence batch.
    ///
    /// Prefill tokens are packed first, up to `max_ubatch_tokens`. When
    /// `prefer_pure_decode` is set and the batch already holds prefill
    /// work, it is returned as-is so decode steps get their own pure
    /// batch; otherwise one decode token per decoding sequence is
    /// appended into the remaining capacity.
    pub fn create_ubatch(&mut self) -> MicroBatch {
        let mut ubatch = MicroBatch::with_capacity(self.config.max_ubatch_tokens);
        for entry in &self.sbatch.sequences {
            if entry.is_prefill {
                for (i, &token_id) in entry.tokens.iter().enumerate() {
                    if ubatch.len() >= self.config.max_ubatch_tokens {
                        break;
                    }
                    ubatch.add_token(BatchToken::new(token_id, entry.seq_idx, i, true));
                }
            }
        }
        if self.config.prefer_pure_decode && !ubatch.is_empty() && ubatch.is_prefill() {
            self.record_ubatch(&ubatch);
            return ubatch;
        }
        for entry in &self.sbatch.sequences {
            if !entry.is_prefill {
                if ubatch.len() >= self.config.max_ubatch_tokens {
                    break;
                }
                // Token id 0 is a placeholder: the actual token for a decode
                // step comes from the model's previous output, which is not
                // known at scheduling time — presumably filled in
                // downstream; verify against the executor.
                ubatch.add_token(BatchToken::new(0, entry.seq_idx, entry.position, false));
            }
        }
        self.record_ubatch(&ubatch);
        ubatch
    }

    /// Folds a freshly created (non-empty) micro-batch into the stats.
    fn record_ubatch(&mut self, ubatch: &MicroBatch) {
        if ubatch.is_empty() {
            return;
        }
        self.stats.ubatches_created += 1;
        self.stats.tokens_processed += ubatch.len() as u64;
        self.stats.prefill_tokens += ubatch.n_prompt_tokens as u64;
        self.stats.decode_tokens += ubatch.n_decode_tokens as u64;
        // Incremental running mean: avg_n = avg_{n-1} * (n-1)/n + x_n / n.
        let n = self.stats.ubatches_created as f64;
        self.stats.avg_ubatch_size =
            self.stats.avg_ubatch_size * (n - 1.0) / n + ubatch.len() as f64 / n;
    }

    /// Read-only view of the current sequence batch.
    pub fn sbatch(&self) -> &SequenceBatch {
        &self.sbatch
    }

    /// Read-only view of the accumulated statistics.
    pub fn stats(&self) -> &BatchStats {
        &self.stats
    }

    /// The active configuration.
    pub fn config(&self) -> &BatchConfig {
        &self.config
    }

    /// Number of sequences currently scheduled.
    pub fn num_sequences(&self) -> usize {
        self.sbatch.len()
    }

    /// Whether another sequence can be admitted right now.
    pub fn has_capacity(&self) -> bool {
        !self.sbatch.is_full()
    }

    /// Sequence-batch utilization as reported by the batch itself.
    pub fn utilization(&self) -> f64 {
        self.sbatch.utilization
    }
}
impl Default for BatchScheduler {
fn default() -> Self {
Self::new()
}
}
/// Latency expectations attached to a request.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Deadline {
    /// Soft latency target, in milliseconds.
    pub target_latency_ms: u64,
    /// Optional hard cutoff in milliseconds; past this the request is
    /// considered expired.
    pub hard_deadline_ms: Option<u64>,
    /// Fraction of requests expected to meet the target (e.g. 0.99).
    pub sla_target: f64,
}
impl Default for Deadline {
fn default() -> Self {
Self {
target_latency_ms: 1000, hard_deadline_ms: None,
sla_target: 0.99, }
}
}
impl Deadline {
    /// Soft deadline: only the target latency differs from the default.
    pub fn with_target(target_ms: u64) -> Self {
        let mut deadline = Self::default();
        deadline.target_latency_ms = target_ms;
        deadline
    }

    /// Strict deadline: a hard cutoff plus a 100% SLA target.
    pub fn strict(target_ms: u64, hard_ms: u64) -> Self {
        Self {
            target_latency_ms: target_ms,
            hard_deadline_ms: Some(hard_ms),
            sla_target: 1.0,
        }
    }
}
/// Tuning knobs for dynamic priority scheduling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicPriorityConfig {
    /// Promote long-waiting requests to a higher priority.
    pub enable_age_promotion: bool,
    /// Milliseconds between promotion passes.
    pub promotion_interval_ms: u64,
    /// Highest priority age promotion may reach.
    pub max_promoted_priority: Priority,
    /// Capacity share per priority class (one entry per priority level).
    pub priority_budgets: [f64; 4],
    /// Factor in request deadlines when ordering work.
    pub enable_deadline_scheduling: bool,
    /// Multiplier applied to urgency when ranking deadline-bound requests.
    pub urgency_factor: f64,
    /// Minimum tokens any admitted request receives per scheduling round.
    pub min_tokens_per_request: usize,
    /// Enforce fair-share allocation across requests.
    pub enable_fair_share: bool,
}
impl Default for DynamicPriorityConfig {
fn default() -> Self {
Self {
enable_age_promotion: true,
promotion_interval_ms: 5000, max_promoted_priority: Priority::High,
priority_budgets: [0.05, 0.30, 0.40, 0.25],
enable_deadline_scheduling: true,
urgency_factor: 2.0,
min_tokens_per_request: 1,
enable_fair_share: true,
}
}
}
impl DynamicPriorityConfig {
    /// Defaults with custom per-priority budget fractions.
    pub fn with_budgets(budgets: [f64; 4]) -> Self {
        let mut config = Self::default();
        config.priority_budgets = budgets;
        config
    }

    /// Builder: disables age-based priority promotion.
    pub fn no_promotion(self) -> Self {
        Self {
            enable_age_promotion: false,
            ..self
        }
    }

    /// Builder: sets the interval between promotion passes, in ms.
    pub fn with_promotion_interval(self, ms: u64) -> Self {
        Self {
            promotion_interval_ms: ms,
            ..self
        }
    }
}
/// A generation request tracked by the dynamic priority scheduler.
#[derive(Debug, Clone)]
pub struct DynamicRequest {
    /// Caller-supplied unique request id.
    pub request_id: u64,
    /// Prompt token ids.
    pub input_ids: Vec<u32>,
    /// Maximum number of tokens to generate.
    pub max_tokens: usize,
    /// Priority the request arrived with.
    pub original_priority: Priority,
    /// Current priority after any age promotions.
    pub effective_priority: Priority,
    /// When the request entered the scheduler.
    pub arrival_time: Instant,
    /// Optional latency expectations.
    pub deadline: Option<Deadline>,
    /// How many times this request has been promoted.
    pub promotions: u32,
    /// Current lifecycle state.
    pub state: SequenceState,
    /// Tokens generated so far.
    pub generated_tokens: Vec<u32>,
    /// Sequence id once scheduled, if any.
    pub seq_id: Option<SeqId>,
    /// Time-to-first-token in milliseconds, once observed.
    pub ttft_ms: Option<f64>,
}
impl DynamicRequest {
    /// Creates a waiting request with normal priority and no deadline;
    /// the arrival clock starts now.
    pub fn new(request_id: u64, input_ids: Vec<u32>, max_tokens: usize) -> Self {
        Self {
            request_id,
            input_ids,
            max_tokens,
            original_priority: Priority::Normal,
            effective_priority: Priority::Normal,
            arrival_time: Instant::now(),
            deadline: None,
            promotions: 0,
            state: SequenceState::Waiting,
            generated_tokens: Vec::new(),
            seq_id: None,
            ttft_ms: None,
        }
    }

    /// Builder: sets both the original and the effective priority.
    pub fn with_priority(self, priority: Priority) -> Self {
        Self {
            original_priority: priority,
            effective_priority: priority,
            ..self
        }
    }

    /// Builder: attaches a deadline.
    pub fn with_deadline(self, deadline: Deadline) -> Self {
        Self {
            deadline: Some(deadline),
            ..self
        }
    }

    /// Milliseconds since the request arrived.
    pub fn wait_time_ms(&self) -> u64 {
        self.arrival_time.elapsed().as_millis() as u64
    }

    /// A deadline-bound request is urgent once it has waited at least
    /// half of its target latency.
    pub fn is_urgent(&self) -> bool {
        self.deadline
            .map_or(false, |d| self.wait_time_ms() >= d.target_latency_ms / 2)
    }

    /// Whether the request has blown past its hard deadline (if any).
    pub fn is_expired(&self) -> bool {
        self.deadline
            .and_then(|d| d.hard_deadline_ms)
            .map_or(false, |hard| self.wait_time_ms() > hard)
    }

    /// Elapsed wait as a fraction of the target latency; 0.0 when there
    /// is no deadline or the target is zero.
    pub fn urgency_score(&self) -> f64 {
        match self.deadline {
            Some(d) if d.target_latency_ms > 0 => {
                self.wait_time_ms() as f64 / d.target_latency_ms as f64
            }
            _ => 0.0,
        }
    }

    /// Tokens still allowed to be generated.
    pub fn remaining_tokens(&self) -> usize {
        self.max_tokens.saturating_sub(self.generated_tokens.len())
    }

    /// Prompt length plus tokens generated so far.
    pub fn total_tokens(&self) -> usize {
        self.input_ids.len() + self.generated_tokens.len()
    }
}