use super::request::{RequestId, RunningRequest};
use std::collections::HashMap;
/// A single request's contribution to one scheduled forward pass.
///
/// Carries the tokens to run this iteration plus the KV-cache placement
/// needed to attend over previously computed context.
#[derive(Debug, Clone)]
pub struct BatchedRequest {
    /// Identifier tying this entry back to the originating request.
    pub request_id: RequestId,
    /// Token ids to process this iteration (full prompt for prefill,
    /// a single token for decode).
    pub token_ids: Vec<u32>,
    /// Absolute position of the first token in `token_ids`.
    pub position_offset: usize,
    /// KV-cache slot assigned to this request.
    pub kv_cache_slot: usize,
    /// Physical block indices backing this request's KV cache.
    pub block_table: Vec<usize>,
    /// `true` for a prefill entry, `false` for a single-token decode step.
    pub is_prefill: bool,
    /// Total sequence length after this iteration's tokens are processed.
    pub seq_len: usize,
    /// Tokens already present in the KV cache (0 for a fresh prefill).
    pub context_len: usize,
}
impl BatchedRequest {
    /// Builds a prefill entry that processes the whole prompt from
    /// position 0 with an empty cached context.
    pub fn prefill(
        request_id: RequestId,
        token_ids: Vec<u32>,
        kv_cache_slot: usize,
        block_table: Vec<usize>,
    ) -> Self {
        // The resulting sequence length equals the prompt length, since
        // nothing is cached yet.
        let prompt_len = token_ids.len();
        Self {
            request_id,
            token_ids,
            position_offset: 0,
            kv_cache_slot,
            block_table,
            is_prefill: true,
            seq_len: prompt_len,
            context_len: 0,
        }
    }

    /// Builds a single-token decode entry appended after `context_len`
    /// already-cached tokens.
    pub fn decode(
        request_id: RequestId,
        token_id: u32,
        position_offset: usize,
        kv_cache_slot: usize,
        block_table: Vec<usize>,
        context_len: usize,
    ) -> Self {
        Self {
            request_id,
            token_ids: vec![token_id],
            position_offset,
            kv_cache_slot,
            block_table,
            is_prefill: false,
            // Cached context plus the one new token being decoded.
            seq_len: context_len + 1,
            context_len,
        }
    }

    /// Number of tokens this entry contributes to the batch.
    pub fn num_tokens(&self) -> usize {
        self.token_ids.len()
    }
}
/// One iteration's worth of work: a mixed set of prefill and decode
/// entries plus aggregate statistics maintained incrementally by `add`.
#[derive(Debug)]
pub struct ScheduledBatch {
    /// Entries selected for this iteration, in insertion order.
    pub requests: Vec<BatchedRequest>,
    /// Sum of `num_tokens()` over all entries.
    pub total_tokens: usize,
    /// Whether at least one prefill entry is present.
    pub has_prefill: bool,
    /// Whether at least one decode entry is present.
    pub has_decode: bool,
    /// Largest `seq_len` among the entries (0 when empty).
    pub max_seq_len: usize,
    /// Caller-supplied identifier for this batch.
    pub batch_id: u64,
}
impl ScheduledBatch {
    /// Creates an empty batch tagged with `batch_id`.
    pub fn new(batch_id: u64) -> Self {
        Self {
            requests: Vec::new(),
            total_tokens: 0,
            has_prefill: false,
            has_decode: false,
            max_seq_len: 0,
            batch_id,
        }
    }

    /// Appends an entry, updating the running aggregates.
    pub fn add(&mut self, request: BatchedRequest) {
        self.total_tokens += request.num_tokens();
        if request.is_prefill {
            self.has_prefill = true;
        } else {
            self.has_decode = true;
        }
        if request.seq_len > self.max_seq_len {
            self.max_seq_len = request.seq_len;
        }
        self.requests.push(request);
    }

    /// Returns `true` when no entries have been added.
    pub fn is_empty(&self) -> bool {
        self.requests.is_empty()
    }

    /// Number of entries in the batch.
    pub fn len(&self) -> usize {
        self.requests.len()
    }

    /// Ids of every request in the batch, in insertion order.
    pub fn request_ids(&self) -> Vec<RequestId> {
        self.requests.iter().map(|r| r.request_id).collect()
    }

    /// Builds a batch containing all prefill entries followed by all
    /// decode entries, with aggregates computed via `add`.
    pub fn merge_prefill_decode(
        prefill: Vec<BatchedRequest>,
        decode: Vec<BatchedRequest>,
        batch_id: u64,
    ) -> Self {
        let mut batch = Self::new(batch_id);
        for req in prefill.into_iter().chain(decode) {
            batch.add(req);
        }
        batch
    }

    /// Borrows the entries split into `(prefill, decode)` groups, each
    /// preserving insertion order.
    pub fn split_by_type(&self) -> (Vec<&BatchedRequest>, Vec<&BatchedRequest>) {
        self.requests.iter().partition(|r| r.is_prefill)
    }

    /// Input token ids per entry, in insertion order.
    pub fn collect_input_ids(&self) -> Vec<Vec<u32>> {
        self.requests.iter().map(|r| r.token_ids.clone()).collect()
    }

    /// Starting positions per entry, in insertion order.
    pub fn collect_positions(&self) -> Vec<usize> {
        self.requests.iter().map(|r| r.position_offset).collect()
    }

    /// KV-cache slots per entry, in insertion order.
    pub fn collect_kv_slots(&self) -> Vec<usize> {
        self.requests.iter().map(|r| r.kv_cache_slot).collect()
    }

    /// Summarizes batch composition (single pass over the entries).
    pub fn stats(&self) -> BatchStats {
        let mut prefill_count = 0usize;
        let mut prefill_tokens = 0usize;
        for req in self.requests.iter().filter(|r| r.is_prefill) {
            prefill_count += 1;
            prefill_tokens += req.num_tokens();
        }
        BatchStats {
            batch_id: self.batch_id,
            total_requests: self.requests.len(),
            prefill_requests: prefill_count,
            decode_requests: self.requests.len() - prefill_count,
            total_tokens: self.total_tokens,
            prefill_tokens,
            // Every non-prefill entry contributes exactly one token.
            decode_tokens: self.total_tokens - prefill_tokens,
            max_seq_len: self.max_seq_len,
        }
    }
}
/// Aggregate composition counters for one `ScheduledBatch`.
#[derive(Debug, Clone, Default)]
pub struct BatchStats {
    /// Id of the batch these stats describe.
    pub batch_id: u64,
    /// Total number of entries in the batch.
    pub total_requests: usize,
    /// Number of prefill entries.
    pub prefill_requests: usize,
    /// Number of decode entries.
    pub decode_requests: usize,
    /// Sum of tokens across all entries.
    pub total_tokens: usize,
    /// Tokens contributed by prefill entries.
    pub prefill_tokens: usize,
    /// Tokens contributed by decode entries (one per decode entry).
    pub decode_tokens: usize,
    /// Largest per-request sequence length in the batch.
    pub max_seq_len: usize,
}
/// Scheduler instruction to run prompt tokens through the model for one
/// request.
#[derive(Debug, Clone)]
pub struct PrefillTask {
    /// Request this task belongs to.
    pub request_id: RequestId,
    /// Prompt token ids to process.
    pub tokens: Vec<u32>,
    /// Absolute position of the first token in `tokens`.
    pub start_position: usize,
    /// KV-cache slot assigned to the request.
    pub kv_cache_slot: usize,
    /// Physical block indices backing the request's KV cache.
    pub block_table: Vec<usize>,
}
/// Scheduler instruction to generate one token for a running request.
#[derive(Debug, Clone)]
pub struct DecodeTask {
    /// Request this task belongs to.
    pub request_id: RequestId,
    /// Token id fed as input to this decode step.
    pub input_token: u32,
    /// Absolute position at which `input_token` is processed.
    pub position: usize,
    /// KV-cache slot assigned to the request.
    pub kv_cache_slot: usize,
    /// Physical block indices backing the request's KV cache.
    pub block_table: Vec<usize>,
    /// Tokens already in the KV cache for this request.
    pub context_len: usize,
}
/// Output of one scheduling pass: the forward-pass work for this
/// iteration plus the cache-management actions that accompany it.
#[derive(Debug)]
pub struct IterationPlan {
    /// Prompt-processing tasks for this iteration.
    pub prefill_tasks: Vec<PrefillTask>,
    /// Single-token generation tasks for this iteration.
    pub decode_tasks: Vec<DecodeTask>,
    /// Requests evicted this iteration.
    pub evicted_requests: Vec<RequestId>,
    /// Requests to swap out of the cache this iteration.
    pub swap_out_requests: Vec<RequestId>,
    /// Requests to swap back into the cache this iteration.
    pub swap_in_requests: Vec<RequestId>,
}
impl IterationPlan {
pub fn empty() -> Self {
Self {
prefill_tasks: Vec::new(),
decode_tasks: Vec::new(),
evicted_requests: Vec::new(),
swap_out_requests: Vec::new(),
swap_in_requests: Vec::new(),
}
}
pub fn has_work(&self) -> bool {
!self.prefill_tasks.is_empty() || !self.decode_tasks.is_empty()
}
pub fn total_requests(&self) -> usize {
self.prefill_tasks.len() + self.decode_tasks.len()
}
pub fn total_tokens(&self) -> usize {
let prefill_tokens: usize = self.prefill_tasks.iter().map(|t| t.tokens.len()).sum();
let decode_tokens = self.decode_tasks.len(); prefill_tokens + decode_tokens
}
pub fn to_scheduled_batch(&self, batch_id: u64) -> ScheduledBatch {
let prefill: Vec<BatchedRequest> = self
.prefill_tasks
.iter()
.map(|t| {
BatchedRequest::prefill(
t.request_id,
t.tokens.clone(),
t.kv_cache_slot,
t.block_table.clone(),
)
})
.collect();
let decode: Vec<BatchedRequest> = self
.decode_tasks
.iter()
.map(|t| {
BatchedRequest::decode(
t.request_id,
t.input_token,
t.position,
t.kv_cache_slot,
t.block_table.clone(),
t.context_len,
)
})
.collect();
ScheduledBatch::merge_prefill_decode(prefill, decode, batch_id)
}
}
/// Per-iteration token accounting with separate prefill/decode caps and
/// an overall cap.
#[derive(Debug, Clone)]
pub struct TokenBudget {
    /// Cap on prefill tokens per iteration.
    pub max_prefill_tokens: usize,
    /// Cap on decode tokens per iteration.
    pub max_decode_tokens: usize,
    /// Cap on total (prefill + decode) tokens per iteration.
    pub max_total_tokens: usize,
    /// Prefill tokens consumed so far this iteration.
    pub prefill_tokens: usize,
    /// Decode tokens consumed so far this iteration.
    pub decode_tokens: usize,
}

impl TokenBudget {
    /// Creates a budget with the given caps and zero usage.
    pub fn new(max_prefill: usize, max_decode: usize, max_total: usize) -> Self {
        Self {
            max_prefill_tokens: max_prefill,
            max_decode_tokens: max_decode,
            max_total_tokens: max_total,
            prefill_tokens: 0,
            decode_tokens: 0,
        }
    }

    /// Clears the usage counters for a new iteration; caps are kept.
    pub fn reset(&mut self) {
        self.prefill_tokens = 0;
        self.decode_tokens = 0;
    }

    /// Tokens consumed so far this iteration.
    pub fn total_tokens(&self) -> usize {
        self.prefill_tokens + self.decode_tokens
    }

    /// Prefill tokens still available — the tighter of the prefill cap
    /// and the overall cap.
    pub fn remaining_prefill(&self) -> usize {
        self.max_prefill_tokens
            .saturating_sub(self.prefill_tokens)
            .min(self.max_total_tokens.saturating_sub(self.total_tokens()))
    }

    /// Decode tokens still available — the tighter of the decode cap
    /// and the overall cap.
    pub fn remaining_decode(&self) -> usize {
        self.max_decode_tokens
            .saturating_sub(self.decode_tokens)
            .min(self.max_total_tokens.saturating_sub(self.total_tokens()))
    }

    /// Reserves `tokens` prefill tokens if they fit; returns whether
    /// the reservation succeeded.
    pub fn try_allocate_prefill(&mut self, tokens: usize) -> bool {
        let fits = tokens <= self.remaining_prefill();
        if fits {
            self.prefill_tokens += tokens;
        }
        fits
    }

    /// Reserves a single decode token if one is available.
    pub fn try_allocate_decode(&mut self) -> bool {
        let available = self.remaining_decode() > 0;
        if available {
            self.decode_tokens += 1;
        }
        available
    }

    /// Whether the overall cap has been reached.
    pub fn is_exhausted(&self) -> bool {
        self.total_tokens() >= self.max_total_tokens
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batched_request() {
        // A prefill carries the whole prompt starting at position 0.
        let pf = BatchedRequest::prefill(RequestId::new(), vec![1, 2, 3, 4], 0, vec![0, 1]);
        assert!(pf.is_prefill);
        assert_eq!(pf.num_tokens(), 4);
        assert_eq!(pf.seq_len, 4);

        // A decode carries exactly one token on top of the cached context.
        let dc = BatchedRequest::decode(RequestId::new(), 5, 10, 1, vec![0, 1, 2], 10);
        assert!(!dc.is_prefill);
        assert_eq!(dc.num_tokens(), 1);
        assert_eq!(dc.context_len, 10);
    }

    #[test]
    fn test_scheduled_batch() {
        let mut sb = ScheduledBatch::new(1);
        sb.add(BatchedRequest::prefill(
            RequestId::new(),
            vec![1, 2, 3],
            0,
            vec![],
        ));
        sb.add(BatchedRequest::decode(RequestId::new(), 4, 5, 1, vec![], 5));

        // Aggregates reflect both entries: 3 prefill tokens + 1 decode.
        assert_eq!(sb.len(), 2);
        assert!(sb.has_prefill);
        assert!(sb.has_decode);
        assert_eq!(sb.total_tokens, 4);

        let (pf, dc) = sb.split_by_type();
        assert_eq!(pf.len(), 1);
        assert_eq!(dc.len(), 1);
    }

    #[test]
    fn test_token_budget() {
        let mut budget = TokenBudget::new(100, 32, 128);
        assert!(budget.try_allocate_prefill(50));
        assert_eq!(budget.prefill_tokens, 50);
        assert_eq!(budget.remaining_prefill(), 50);
        assert!(budget.try_allocate_decode());
        assert_eq!(budget.decode_tokens, 1);
        // A 60-token prefill exceeds the remaining prefill budget of 50.
        assert!(!budget.try_allocate_prefill(60));
        budget.reset();
        assert_eq!(budget.total_tokens(), 0);
    }

    #[test]
    fn test_iteration_plan() {
        let plan = IterationPlan {
            prefill_tasks: vec![PrefillTask {
                request_id: RequestId::new(),
                tokens: vec![1, 2, 3, 4, 5],
                start_position: 0,
                kv_cache_slot: 0,
                block_table: vec![],
            }],
            decode_tasks: vec![DecodeTask {
                request_id: RequestId::new(),
                input_token: 6,
                position: 10,
                kv_cache_slot: 1,
                block_table: vec![],
                context_len: 10,
            }],
            evicted_requests: vec![],
            swap_out_requests: vec![],
            swap_in_requests: vec![],
        };

        assert!(plan.has_work());
        assert_eq!(plan.total_requests(), 2);
        // 5 prefill tokens + 1 decode token.
        assert_eq!(plan.total_tokens(), 6);

        let sb = plan.to_scheduled_batch(42);
        assert_eq!(sb.batch_id, 42);
        assert_eq!(sb.len(), 2);
    }
}