impl SlotManager {
    /// Builds a manager that owns `num_slots` slots, created by calling
    /// `Slot::new` with each index in `0..num_slots`.
    ///
    /// `max_context_length` is stored as given and request ids start at 0.
    pub fn new(num_slots: usize, max_context_length: usize) -> Self {
        Self {
            slots: (0..num_slots).map(Slot::new).collect(),
            max_context_length,
            next_request_id: 0,
        }
    }

    /// Total number of slots, idle or not.
    pub fn num_slots(&self) -> usize {
        self.slots.len()
    }

    /// Number of slots that currently report `is_idle()`.
    pub fn num_idle_slots(&self) -> usize {
        self.slots.iter().filter(|slot| slot.is_idle()).count()
    }

    /// Number of slots that are not idle.
    pub fn num_active_slots(&self) -> usize {
        self.active_slots().count()
    }

    /// Index of the first idle slot, or `None` when every slot is busy
    /// (or there are no slots at all).
    pub fn find_idle_slot(&self) -> Option<usize> {
        self.slots
            .iter()
            .enumerate()
            .find(|(_, slot)| slot.is_idle())
            .map(|(index, _)| index)
    }

    /// Hands a new request to the first idle slot.
    ///
    /// Returns `Some((slot_id, request_id))` on success. When no slot is idle
    /// the request is dropped, no request id is consumed, and `None` is
    /// returned. Request ids are allocated sequentially.
    pub fn assign_request(
        &mut self,
        input_tokens: Vec<u32>,
        max_tokens: usize,
    ) -> Option<(usize, u64)> {
        self.find_idle_slot().map(|slot_id| {
            // Reserve the id before handing the work to the slot.
            let request_id = self.next_request_id;
            self.next_request_id += 1;
            self.slots[slot_id].assign(request_id, input_tokens, max_tokens);
            (slot_id, request_id)
        })
    }

    /// Shared access to the slot at `slot_id`, or `None` if out of range.
    pub fn get_slot(&self, slot_id: usize) -> Option<&Slot> {
        self.slots.get(slot_id)
    }

    /// Exclusive access to the slot at `slot_id`, or `None` if out of range.
    pub fn get_slot_mut(&mut self, slot_id: usize) -> Option<&mut Slot> {
        self.slots.get_mut(slot_id)
    }

    /// Iterates over every slot that is not idle, in slot order.
    pub fn active_slots(&self) -> impl Iterator<Item = &Slot> {
        self.slots.iter().filter(|slot| !slot.is_idle())
    }

    /// Iterates over every slot that reports `is_generating()`, in slot order.
    pub fn generating_slots(&self) -> impl Iterator<Item = &Slot> {
        self.slots.iter().filter(|slot| slot.is_generating())
    }

    /// Indices of the slots that report `is_generating()`, in ascending order.
    pub fn batch_slots(&self) -> Vec<usize> {
        let mut generating = Vec::new();
        for (index, slot) in self.slots.iter().enumerate() {
            if slot.is_generating() {
                generating.push(index);
            }
        }
        generating
    }

    /// Fraction of slots that are active, in `0.0..=1.0`.
    ///
    /// A manager without slots reports `0.0` rather than dividing by zero.
    pub fn utilization(&self) -> f64 {
        let total = self.slots.len();
        if total == 0 {
            return 0.0;
        }
        self.num_active_slots() as f64 / total as f64
    }

    /// Sum of `tokens_per_second()` over all slots.
    pub fn aggregate_tokens_per_second(&self) -> f64 {
        self.slots
            .iter()
            .map(|slot| slot.tokens_per_second())
            .sum()
    }
}
/// Classifies a batch by the kinds of tokens it carries.
///
/// `MicroBatch` recomputes this from its prompt/decode token counters every
/// time a token is added.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BatchType {
/// At least one prompt token and no decode tokens.
Prefill,
/// No prompt tokens. This is also the `Default` value and what an empty
/// `MicroBatch` reports.
Decode,
/// Both prompt tokens and decode tokens are present.
Mixed,
}
impl Default for BatchType {
fn default() -> Self {
Self::Decode
}
}
/// One token scheduled into a batch, tagged with the sequence it belongs to.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct BatchToken {
/// Id of the token (presumably a vocabulary index — confirm against the
/// tokenizer).
pub token_id: u32,
/// Identifier of the owning sequence. `MicroBatch::add_token` records each
/// distinct value once in `MicroBatch::seq_indices`.
pub seq_idx: usize,
/// Position of the token. `MicroBatch::add_token` treats it as zero-based:
/// the batch's `max_seq_len` is kept at one past the largest `seq_pos` seen.
pub seq_pos: usize,
/// `true` counts toward a batch's prompt tokens, `false` toward its decode
/// tokens.
pub is_prompt: bool,
}
impl BatchToken {
pub fn new(token_id: u32, seq_idx: usize, seq_pos: usize, is_prompt: bool) -> Self {
Self {
token_id,
seq_idx,
seq_pos,
is_prompt,
}
}
}
/// A flat list of tokens drawn from one or more sequences, plus summary
/// counters describing it.
///
/// The summary fields are kept in sync by `add_token` and `clear`. All fields
/// are public, so writing to them directly can leave the summary stale.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MicroBatch {
/// Tokens in the order they were added.
pub tokens: Vec<BatchToken>,
/// Distinct `seq_idx` values of the added tokens, in first-seen order.
pub seq_indices: Vec<usize>,
/// Prefill / decode / mixed classification derived from the two counters
/// below.
pub batch_type: BatchType,
/// One past the largest `seq_pos` added so far; 0 for an empty batch.
pub max_seq_len: usize,
/// Number of added tokens with `is_prompt == true`.
pub n_prompt_tokens: usize,
/// Number of added tokens with `is_prompt == false`.
pub n_decode_tokens: usize,
}
impl MicroBatch {
    /// Creates an empty batch without reserving any storage.
    pub fn new() -> Self {
        // A zero-capacity `Vec` does not allocate, so this matches `Vec::new()`.
        Self::with_capacity(0)
    }

    /// Creates an empty batch with room for `capacity` tokens.
    ///
    /// Only the token list is pre-sized; `seq_indices` starts unallocated.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            tokens: Vec::with_capacity(capacity),
            seq_indices: Vec::new(),
            batch_type: BatchType::Decode,
            max_seq_len: 0,
            n_prompt_tokens: 0,
            n_decode_tokens: 0,
        }
    }

    /// Appends `token` and refreshes every summary field.
    ///
    /// * bumps the prompt or decode counter according to `token.is_prompt`;
    /// * records `token.seq_idx` in `seq_indices` the first time it is seen;
    /// * raises `max_seq_len` to one past `token.seq_pos` when that is larger;
    /// * recomputes `batch_type`.
    pub fn add_token(&mut self, token: BatchToken) {
        if token.is_prompt {
            self.n_prompt_tokens += 1;
        } else {
            self.n_decode_tokens += 1;
        }

        if !self.seq_indices.contains(&token.seq_idx) {
            self.seq_indices.push(token.seq_idx);
        }

        // `seq_pos` is caller-supplied (and deserializable), so it may be
        // `usize::MAX`. Saturate instead of panicking in debug builds or
        // wrapping to 0 in release builds.
        let required_len = token.seq_pos.saturating_add(1);
        self.max_seq_len = self.max_seq_len.max(required_len);

        self.tokens.push(token);
        self.update_batch_type();
    }

    /// Derives `batch_type` from the prompt/decode counters.
    ///
    /// A batch with no prompt tokens (including an empty one) is `Decode`.
    fn update_batch_type(&mut self) {
        let has_prompt = self.n_prompt_tokens > 0;
        let has_decode = self.n_decode_tokens > 0;
        self.batch_type = match (has_prompt, has_decode) {
            (true, true) => BatchType::Mixed,
            (true, false) => BatchType::Prefill,
            (false, true) | (false, false) => BatchType::Decode,
        };
    }

    /// Number of tokens in the batch.
    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    /// `true` when the batch holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }

    /// Number of distinct sequences that contributed tokens.
    pub fn num_sequences(&self) -> usize {
        self.seq_indices.len()
    }

    /// `true` when the batch is classified as [`BatchType::Prefill`].
    pub fn is_prefill(&self) -> bool {
        self.batch_type == BatchType::Prefill
    }

    /// `true` when the batch is classified as [`BatchType::Decode`].
    ///
    /// Note that an empty batch is classified as `Decode`.
    pub fn is_decode(&self) -> bool {
        self.batch_type == BatchType::Decode
    }

    /// `true` when the batch is classified as [`BatchType::Mixed`].
    pub fn is_mixed(&self) -> bool {
        self.batch_type == BatchType::Mixed
    }

    /// Token ids in insertion order, as a freshly allocated vector.
    pub fn token_ids(&self) -> Vec<u32> {
        self.tokens.iter().map(|t| t.token_id).collect()
    }

    /// Token positions (`seq_pos`) in insertion order, as a freshly allocated
    /// vector.
    pub fn positions(&self) -> Vec<usize> {
        self.tokens.iter().map(|t| t.seq_pos).collect()
    }

    /// Empties the batch and resets every summary field.
    ///
    /// The vectors are cleared in place, so their allocations are retained
    /// for reuse.
    pub fn clear(&mut self) {
        self.tokens.clear();
        self.seq_indices.clear();
        self.batch_type = BatchType::Decode;
        self.max_seq_len = 0;
        self.n_prompt_tokens = 0;
        self.n_decode_tokens = 0;
    }
}
/// One sequence's membership record inside a `SequenceBatch`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SequenceBatchEntry {
/// Key used by `SequenceBatch::get`, `get_mut` and `remove_sequence` to
/// locate this entry.
pub seq_idx: usize,
/// Presumably the index of the `SlotManager` slot serving this sequence —
/// TODO confirm against the caller.
pub slot_id: usize,
/// Presumably the id handed out by `SlotManager::assign_request` — TODO
/// confirm against the caller.
pub request_id: u64,
/// Starts at 0 in `new` and is set by `at_position`. NOTE(review): looks
/// like the sequence's current token position — confirm.
pub position: usize,
/// Tokens attached to this entry; empty until `with_tokens` is used.
pub tokens: Vec<u32>,
/// `true` for a freshly created entry; `decoding()` switches it to `false`.
/// `SequenceBatch` uses it to split entries into prefill and decode groups.
pub is_prefill: bool,
}
impl SequenceBatchEntry {
    /// Starts an entry for the given sequence, slot and request.
    ///
    /// The entry begins at position 0 with no tokens and is flagged as
    /// prefill; use the builder methods below to adjust it.
    pub fn new(seq_idx: usize, slot_id: usize, request_id: u64) -> Self {
        Self {
            position: 0,
            tokens: Vec::new(),
            is_prefill: true,
            seq_idx,
            slot_id,
            request_id,
        }
    }

    /// Builder step: replaces the entry's tokens with `tokens`.
    pub fn with_tokens(self, tokens: Vec<u32>) -> Self {
        Self { tokens, ..self }
    }

    /// Builder step: sets the entry's position.
    pub fn at_position(self, position: usize) -> Self {
        Self { position, ..self }
    }

    /// Builder step: clears the prefill flag, marking the entry as decoding.
    pub fn decoding(self) -> Self {
        Self {
            is_prefill: false,
            ..self
        }
    }
}
/// A bounded collection of `SequenceBatchEntry` values.
///
/// Note that the derived `Default` yields `max_batch_size == 0`, i.e. a batch
/// that `add_sequence` will never accept entries into.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SequenceBatch {
/// Entries in insertion order.
pub sequences: Vec<SequenceBatchEntry>,
/// Maximum number of entries; `add_sequence` refuses once this is reached.
pub max_batch_size: usize,
/// `sequences.len() / max_batch_size`, or 0.0 when `max_batch_size` is 0.
/// Refreshed by `add_sequence`, `remove_sequence` and `clear`; direct writes
/// to `sequences` leave it stale.
pub utilization: f64,
}
impl SequenceBatch {
    /// Creates an empty batch that accepts at most `max_batch_size` entries.
    ///
    /// Storage for `max_batch_size` entries is reserved up front.
    pub fn new(max_batch_size: usize) -> Self {
        Self {
            utilization: 0.0,
            max_batch_size,
            sequences: Vec::with_capacity(max_batch_size),
        }
    }

    /// Appends `entry` unless the batch is already full.
    ///
    /// Returns `true` if the entry was stored, `false` if it was rejected
    /// (in which case `entry` is dropped and nothing changes).
    pub fn add_sequence(&mut self, entry: SequenceBatchEntry) -> bool {
        if self.is_full() {
            return false;
        }
        self.sequences.push(entry);
        self.update_utilization();
        true
    }

    /// Removes and returns the first entry whose `seq_idx` matches, keeping
    /// the order of the remaining entries. Returns `None` if there is no
    /// such entry.
    pub fn remove_sequence(&mut self, seq_idx: usize) -> Option<SequenceBatchEntry> {
        self.sequences
            .iter()
            .position(|candidate| candidate.seq_idx == seq_idx)
            .map(|index| {
                let removed = self.sequences.remove(index);
                self.update_utilization();
                removed
            })
    }

    /// Recomputes `utilization` as the filled fraction of `max_batch_size`.
    fn update_utilization(&mut self) {
        self.utilization = match self.max_batch_size {
            // A zero-capacity batch is reported as 0% used, avoiding 0/0.
            0 => 0.0,
            capacity => self.sequences.len() as f64 / capacity as f64,
        };
    }

    /// Number of entries currently in the batch.
    pub fn len(&self) -> usize {
        self.sequences.len()
    }

    /// `true` when the batch holds no entries.
    pub fn is_empty(&self) -> bool {
        self.sequences.is_empty()
    }

    /// `true` when the batch has reached `max_batch_size` entries.
    pub fn is_full(&self) -> bool {
        self.sequences.len() >= self.max_batch_size
    }

    /// Iterates over the entries flagged `is_prefill`, in insertion order.
    pub fn prefill_sequences(&self) -> impl Iterator<Item = &SequenceBatchEntry> {
        self.sequences.iter().filter(|entry| entry.is_prefill)
    }

    /// Iterates over the entries not flagged `is_prefill`, in insertion order.
    pub fn decode_sequences(&self) -> impl Iterator<Item = &SequenceBatchEntry> {
        self.sequences.iter().filter(|entry| !entry.is_prefill)
    }

    /// Number of entries flagged `is_prefill`.
    pub fn num_prefill(&self) -> usize {
        self.prefill_sequences().count()
    }

    /// Number of entries not flagged `is_prefill`.
    pub fn num_decode(&self) -> usize {
        self.decode_sequences().count()
    }

    /// Removes every entry and resets `utilization` to 0.0.
    ///
    /// `max_batch_size` is left untouched.
    pub fn clear(&mut self) {
        self.utilization = 0.0;
        self.sequences.clear();
    }

    /// Shared access to the first entry whose `seq_idx` matches, if any.
    pub fn get(&self, seq_idx: usize) -> Option<&SequenceBatchEntry> {
        self.sequences
            .iter()
            .find(|entry| entry.seq_idx == seq_idx)
    }

    /// Exclusive access to the first entry whose `seq_idx` matches, if any.
    pub fn get_mut(&mut self, seq_idx: usize) -> Option<&mut SequenceBatchEntry> {
        self.sequences
            .iter_mut()
            .find(|entry| entry.seq_idx == seq_idx)
    }
}
/// Tunable limits and switches for batch construction.
///
/// NOTE(review): nothing in the visible part of this file reads these fields,
/// so the per-field notes below are inferred from the names only — confirm
/// each against the code that consumes the config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
/// Presumably the maximum number of tokens in one `MicroBatch` — TODO confirm.
pub max_ubatch_tokens: usize,
/// Presumably the maximum number of entries in one `SequenceBatch` — TODO
/// confirm.
pub max_sbatch_sequences: usize,
/// Presumably asks the batcher to favour decode-only batches over mixed
/// ones — TODO confirm.
pub prefer_pure_decode: bool,
/// Presumably the per-sequence context limit, in tokens — TODO confirm.
pub max_context_length: usize,
/// Presumably toggles dynamic batching; the exact effect is not visible
/// here — TODO confirm.
pub dynamic_batching: bool,
}