#![allow(clippy::too_many_arguments)]
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
/// Configuration for chunked prefill: splitting long prompt prefills into
/// fixed-size token chunks so decode work can be interleaved between them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkedPrefillConfig {
    /// Master switch; when false every prompt bypasses chunking.
    pub enabled: bool,
    /// Number of prompt tokens processed per chunk.
    pub chunk_size: usize,
    /// Prompts shorter than this are prefilled in one shot (no chunking).
    pub min_prompt_length: usize,
    /// Allow decode steps to run between prefill chunks.
    pub allow_decode_interleave: bool,
    /// Keep a partially-prefilled sequence at the front of the queue
    /// (run-to-completion) instead of rotating it to the back.
    pub boost_partial_prefill: bool,
    /// Upper bound on chunks per sequence.
    /// NOTE(review): not enforced anywhere in this file — confirm whether a
    /// caller applies it or whether enforcement is missing.
    pub max_chunks: usize,
}
impl Default for ChunkedPrefillConfig {
fn default() -> Self {
Self {
enabled: true,
chunk_size: 512,
min_prompt_length: 256,
allow_decode_interleave: true,
boost_partial_prefill: true,
max_chunks: 16,
}
}
}
impl ChunkedPrefillConfig {
    /// Builder-style override of the per-chunk token count.
    pub fn with_chunk_size(mut self, size: usize) -> Self {
        self.chunk_size = size;
        self
    }

    /// Default settings with chunking switched off entirely.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::default()
        }
    }

    /// Latency-oriented preset: small chunks, low chunking threshold, and a
    /// higher chunk cap. Interleaving/boosting match the defaults.
    pub fn low_latency() -> Self {
        Self {
            chunk_size: 128,
            min_prompt_length: 64,
            max_chunks: 32,
            ..Self::default()
        }
    }

    /// Throughput-oriented preset: large chunks, no decode interleaving, no
    /// front-of-queue boosting, and a lower chunk cap.
    pub fn high_throughput() -> Self {
        Self {
            chunk_size: 1024,
            min_prompt_length: 512,
            allow_decode_interleave: false,
            boost_partial_prefill: false,
            max_chunks: 8,
            ..Self::default()
        }
    }
}
/// Per-sequence progress tracking for a chunked prefill.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkedPrefillState {
    /// Scheduler-assigned sequence identifier.
    pub seq_id: u64,
    /// Total prompt tokens to prefill.
    pub total_tokens: usize,
    /// Tokens prefilled so far.
    pub processed_tokens: usize,
    /// Index of the next chunk to run (number of chunks completed).
    pub current_chunk: usize,
    /// Planned chunk count (`ceil(total_tokens / chunk_size)` at creation).
    pub total_chunks: usize,
    /// Creation timestamp in ms; initialized to 0 here — presumably set by a
    /// caller, verify.
    pub start_time_ms: u64,
    /// Per-chunk latencies in ms, appended as chunks complete.
    pub chunk_latencies: Vec<u64>,
}
impl ChunkedPrefillState {
pub fn new(seq_id: u64, total_tokens: usize, chunk_size: usize) -> Self {
let total_chunks = total_tokens.div_ceil(chunk_size);
Self {
seq_id,
total_tokens,
processed_tokens: 0,
current_chunk: 0,
total_chunks,
start_time_ms: 0,
chunk_latencies: Vec::with_capacity(total_chunks),
}
}
pub fn next_chunk(&self, chunk_size: usize) -> std::ops::Range<usize> {
let start = self.processed_tokens;
let end = (start + chunk_size).min(self.total_tokens);
start..end
}
pub fn advance(&mut self, tokens_processed: usize, latency_ms: u64) {
self.processed_tokens += tokens_processed;
self.current_chunk += 1;
self.chunk_latencies.push(latency_ms);
}
pub fn is_complete(&self) -> bool {
self.processed_tokens >= self.total_tokens
}
pub fn progress(&self) -> f64 {
if self.total_tokens == 0 {
100.0
} else {
(self.processed_tokens as f64 / self.total_tokens as f64) * 100.0
}
}
pub fn remaining_tokens(&self) -> usize {
self.total_tokens.saturating_sub(self.processed_tokens)
}
pub fn avg_chunk_latency_ms(&self) -> f64 {
if self.chunk_latencies.is_empty() {
0.0
} else {
self.chunk_latencies.iter().sum::<u64>() as f64 / self.chunk_latencies.len() as f64
}
}
}
/// Aggregate counters for the chunked-prefill scheduler.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChunkedPrefillStats {
    /// Sequences that went through chunked prefill.
    pub chunked_sequences: u64,
    /// Sequences that bypassed chunking (disabled or prompt too short).
    pub bypassed_sequences: u64,
    /// Total chunks completed across all sequences.
    pub chunks_processed: u64,
    /// Decode steps interleaved between prefill chunks.
    pub decode_interleaves: u64,
    /// Sum of all chunk latencies in ms.
    pub total_chunk_latency_ms: u64,
    /// Largest single-chunk latency observed, in ms.
    pub max_chunk_latency_ms: u64,
    /// NOTE(review): despite the name, `record_prefix_cache_hit` adds the
    /// number of tokens saved here, so this accumulates tokens rather than a
    /// hit count — confirm intended semantics.
    pub prefix_cache_hits: u64,
}
impl ChunkedPrefillStats {
    /// Mean chunk latency in ms; 0.0 when no chunk has completed yet.
    pub fn avg_chunk_latency_ms(&self) -> f64 {
        match self.chunks_processed {
            0 => 0.0,
            n => self.total_chunk_latency_ms as f64 / n as f64,
        }
    }

    /// Fraction of submitted sequences that were chunked (0.0 when none).
    pub fn chunking_rate(&self) -> f64 {
        let submitted = self.chunked_sequences + self.bypassed_sequences;
        match submitted {
            0 => 0.0,
            n => self.chunked_sequences as f64 / n as f64,
        }
    }
}
/// FIFO scheduler for chunked prefills: tracks per-sequence progress, a queue
/// of sequences awaiting chunks, and aggregate statistics.
pub struct ChunkedPrefillScheduler {
    config: ChunkedPrefillConfig,
    // Progress state for every chunked sequence, keyed by sequence id.
    // Entries persist after completion until `remove`/`clear`.
    active_prefills: HashMap<u64, ChunkedPrefillState>,
    // Sequence ids awaiting prefill chunks, in dispatch order.
    prefill_queue: VecDeque<u64>,
    stats: ChunkedPrefillStats,
    // Monotonically increasing id handed out by `submit`.
    next_seq_id: u64,
}
impl ChunkedPrefillScheduler {
    /// Creates a scheduler with the given configuration and no tracked work.
    pub fn new(config: ChunkedPrefillConfig) -> Self {
        Self {
            config,
            active_prefills: HashMap::new(),
            prefill_queue: VecDeque::new(),
            stats: ChunkedPrefillStats::default(),
            next_seq_id: 0,
        }
    }

    /// Registers a new prompt and decides whether it will be chunked.
    ///
    /// Returns the assigned sequence id and `true` if the prompt was enqueued
    /// for chunked prefill, `false` if it bypasses chunking.
    pub fn submit(&mut self, prompt_tokens: usize) -> (u64, bool) {
        let seq_id = self.next_seq_id;
        self.next_seq_id += 1;
        // Chunk only when enabled and the prompt is long enough to benefit.
        let use_chunking = self.config.enabled && prompt_tokens >= self.config.min_prompt_length;
        if use_chunking {
            let state = ChunkedPrefillState::new(seq_id, prompt_tokens, self.config.chunk_size);
            self.active_prefills.insert(seq_id, state);
            self.prefill_queue.push_back(seq_id);
            self.stats.chunked_sequences += 1;
        } else {
            self.stats.bypassed_sequences += 1;
        }
        (seq_id, use_chunking)
    }

    /// Returns the next `(seq_id, token_range)` to prefill, or `None` when no
    /// work is queued. Stale front entries (completed or removed sequences)
    /// are lazily discarded while scanning.
    pub fn next_chunk(&mut self) -> Option<(u64, std::ops::Range<usize>)> {
        while let Some(&seq_id) = self.prefill_queue.front() {
            if let Some(state) = self.active_prefills.get(&seq_id) {
                if !state.is_complete() {
                    return Some((seq_id, state.next_chunk(self.config.chunk_size)));
                }
            }
            // Front entry is stale: drop it and look at the next one.
            self.prefill_queue.pop_front();
        }
        None
    }

    /// Records completion of one chunk for `seq_id`: advances per-sequence
    /// progress, updates latency stats, and adjusts the queue. A finished
    /// sequence leaves the queue; an unfinished one keeps its front slot when
    /// `boost_partial_prefill` is set, otherwise it is rotated to the back
    /// for round-robin fairness. Unknown ids are ignored.
    pub fn complete_chunk(&mut self, seq_id: u64, tokens_processed: usize, latency_ms: u64) {
        let Some(state) = self.active_prefills.get_mut(&seq_id) else {
            return;
        };
        state.advance(tokens_processed, latency_ms);
        self.stats.chunks_processed += 1;
        self.stats.total_chunk_latency_ms += latency_ms;
        self.stats.max_chunk_latency_ms = self.stats.max_chunk_latency_ms.max(latency_ms);
        if state.is_complete() {
            self.remove_from_queue(seq_id);
        } else if !self.config.boost_partial_prefill {
            // Rotate to the back so other queued sequences get a turn.
            if self.remove_from_queue(seq_id) {
                self.prefill_queue.push_back(seq_id);
            }
        }
        // With boost_partial_prefill, the sequence keeps its queue position.
    }

    // Removes `seq_id` from the queue if present; true when it was found.
    // Linear scan: the queue is expected to stay small.
    fn remove_from_queue(&mut self, seq_id: u64) -> bool {
        match self.prefill_queue.iter().position(|&id| id == seq_id) {
            Some(pos) => {
                self.prefill_queue.remove(pos);
                true
            }
            None => false,
        }
    }

    /// Counts one decode step interleaved between prefill chunks.
    pub fn record_decode_interleave(&mut self) {
        self.stats.decode_interleaves += 1;
    }

    /// True when config allows interleaving and prefill work is queued.
    pub fn should_interleave_decode(&self) -> bool {
        self.config.allow_decode_interleave && !self.prefill_queue.is_empty()
    }

    /// Progress state for `seq_id`, if it was chunked and not yet removed.
    pub fn get_state(&self, seq_id: u64) -> Option<&ChunkedPrefillState> {
        self.active_prefills.get(&seq_id)
    }

    /// True while `seq_id` still has unprefilled tokens.
    pub fn has_pending_prefill(&self, seq_id: u64) -> bool {
        self.active_prefills
            .get(&seq_id)
            .is_some_and(|s| !s.is_complete())
    }

    /// Drops all tracking for `seq_id` (e.g. on cancellation), returning its
    /// final state if it was known.
    pub fn remove(&mut self, seq_id: u64) -> Option<ChunkedPrefillState> {
        self.remove_from_queue(seq_id);
        self.active_prefills.remove(&seq_id)
    }

    /// Number of tracked sequences that still have tokens to prefill.
    pub fn pending_count(&self) -> usize {
        self.active_prefills
            .values()
            .filter(|s| !s.is_complete())
            .count()
    }

    /// Current queue length (may include stale entries not yet discarded).
    pub fn queue_len(&self) -> usize {
        self.prefill_queue.len()
    }

    /// Aggregate scheduler statistics.
    pub fn stats(&self) -> &ChunkedPrefillStats {
        &self.stats
    }

    /// Active configuration.
    pub fn config(&self) -> &ChunkedPrefillConfig {
        &self.config
    }

    /// Clears all per-sequence state and the queue. Stats and the sequence-id
    /// counter are intentionally preserved.
    pub fn clear(&mut self) {
        self.active_prefills.clear();
        self.prefill_queue.clear();
    }

    /// Credits tokens skipped via prefix-cache reuse.
    /// NOTE(review): `prefix_cache_hits` accumulates tokens saved, not a hit
    /// count, despite its name — confirm intended semantics.
    pub fn record_prefix_cache_hit(&mut self, tokens_saved: usize) {
        self.stats.prefix_cache_hits += tokens_saved as u64;
    }
}
impl Default for ChunkedPrefillScheduler {
fn default() -> Self {
Self::new(ChunkedPrefillConfig::default())
}
}
include!("chunked_prefill_config_default.rs");