use crate::runtime::config::SchedulerTickBudget;
use crate::runtime::llama_token;
use crate::runtime::numeric::{
positive_i32_to_usize, saturating_u32_to_i32, saturating_usize_to_i32,
};
use crate::runtime::request::GenerateRequestId;
use super::SlotState;
#[cfg(test)]
mod apply_results;
mod helpers;
use helpers::resolve_prefill_slice_cap;
/// Number of slot indices tracked in the `u64` occupancy bitmask when counting occupied slots.
/// Slot indices at or above this bound are instead deduplicated through an overflow list.
const FAST_OCCUPIED_SLOT_BITS: usize = u64::BITS as usize;
/// Distinguishes what a single batch contribution feeds into the model.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BatchContributionKind {
    /// A prompt token scheduled during prefill.
    Prefill = 0,
    /// The most recently generated token of a slot, scheduled with logits requested.
    Decode,
}
/// One token placed into the shared batch for a single slot.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct BatchContribution {
    /// Index of the contributing slot within the planner's `slots` slice.
    pub slot_index: usize,
    /// Identifier of the request occupying that slot when the plan was built.
    pub request_id: GenerateRequestId,
    /// Whether this token is a prefill (prompt) or decode (generated) token.
    pub kind: BatchContributionKind,
    /// Prompt token for prefill, last generated token for decode.
    pub token: llama_token,
    /// Position of the token: the prompt index for prefill, the slot's `n_past` for decode.
    pub position: i32,
    /// Set for every decode token and for the final prompt token of a prefill completed
    /// within the tick.
    pub request_logits: bool,
}
/// The batch assembled for one scheduler tick, plus reusable scratch buffers.
///
/// The private vectors are working storage for the planner; they are kept on the plan so their
/// allocations can be reused across ticks.
#[derive(Debug, Default)]
#[derive(Clone, PartialEq, Eq)]
pub struct SharedBatchPlan {
    /// All tokens scheduled this tick: decode tokens first, then prefill tokens.
    pub contributions: Vec<BatchContribution>,
    /// Number of prefill contributions in `contributions`.
    pub prefill_token_count: i32,
    /// Number of decode contributions in `contributions`.
    pub decode_token_count: i32,
    /// Number of distinct slots that contributed at least one token.
    pub occupied_slot_count: i32,
    /// Round-robin working set of slot indices that still have prompt tokens to schedule.
    active_prefill_slots: Vec<usize>,
    /// Tokens already scheduled this tick for each entry of `active_prefill_slots`
    /// (index-aligned with it).
    in_tick_offset: Vec<usize>,
    /// Distinct contributing slot indices that do not fit in the fast occupancy bitmask.
    occupied_overflow_slots: Vec<usize>,
}
impl SharedBatchPlan {
    /// Creates an empty plan with buffers pre-sized for `max_contributions` tokens and
    /// `max_slots` slots, so steady-state planning does not need to reallocate.
    pub fn with_capacities(max_contributions: usize, max_slots: usize) -> Self {
        // Only slots beyond the bitmask width ever land in the overflow list.
        let overflow_capacity = max_slots.saturating_sub(FAST_OCCUPIED_SLOT_BITS);
        Self {
            contributions: Vec::with_capacity(max_contributions),
            prefill_token_count: 0,
            decode_token_count: 0,
            occupied_slot_count: 0,
            active_prefill_slots: Vec::with_capacity(max_slots),
            in_tick_offset: Vec::with_capacity(max_slots),
            occupied_overflow_slots: Vec::with_capacity(overflow_capacity),
        }
    }

    /// Empties the plan while keeping every buffer's allocation.
    pub fn reset(&mut self) {
        // Destructure exhaustively so a newly added field cannot be forgotten here.
        let Self {
            contributions,
            prefill_token_count,
            decode_token_count,
            occupied_slot_count,
            active_prefill_slots,
            in_tick_offset,
            occupied_overflow_slots,
        } = self;
        contributions.clear();
        active_prefill_slots.clear();
        in_tick_offset.clear();
        occupied_overflow_slots.clear();
        *prefill_token_count = 0;
        *decode_token_count = 0;
        *occupied_slot_count = 0;
    }

    /// Drops the active prefill entry at `index` from both index-aligned scratch vectors.
    ///
    /// Order is preserved (shift, not swap) so the round-robin cursor stays meaningful.
    /// Panics if `index` is out of bounds.
    fn erase_active_prefill_slot(&mut self, index: usize) {
        self.active_prefill_slots.remove(index);
        self.in_tick_offset.remove(index);
    }
}
/// Stateless planner that fills a [`SharedBatchPlan`] from slot state and a tick budget.
#[derive(Clone, Copy, Debug, Default)]
pub struct BatchPlanner;
impl BatchPlanner {
    /// Test-only convenience wrapper around [`BatchPlanner::build_policy_batch_into`] that
    /// builds into a freshly allocated [`SharedBatchPlan`] and returns it.
    #[cfg(test)]
    pub fn build_policy_batch(
        &self,
        slots: &[SlotState],
        decode_slots: &[usize],
        prefill_slots: &[usize],
        budget: SchedulerTickBudget,
        prefill_chunk_size: i32,
    ) -> SharedBatchPlan {
        let mut plan = SharedBatchPlan::default();
        self.build_policy_batch_into(
            &mut plan,
            slots,
            decode_slots,
            prefill_slots,
            budget,
            prefill_chunk_size,
        );
        plan
    }

    /// Fills `plan` with the contributions for one scheduler tick, reusing its buffers.
    ///
    /// The plan is reset first. If `budget.total_token_budget` is not positive, it is left
    /// empty. Otherwise:
    ///
    /// 1. One decode contribution is emitted per entry of `decode_slots`, in order, until the
    ///    effective decode budget is spent. Slots that are out of range, have no request, or
    ///    have no generated token are skipped.
    /// 2. Remaining prompt tokens of `prefill_slots` are scheduled round-robin, in slices
    ///    capped by `resolve_prefill_slice_cap`. This continues until the effective prefill
    ///    budget is spent, every prompt is fully scheduled, or no slot can make progress.
    ///    The last token of a prompt completed this tick requests logits.
    /// 3. `occupied_slot_count` is set to the number of distinct contributing slots.
    pub fn build_policy_batch_into(
        &self,
        plan: &mut SharedBatchPlan,
        slots: &[SlotState],
        decode_slots: &[usize],
        prefill_slots: &[usize],
        budget: SchedulerTickBudget,
        prefill_chunk_size: i32,
    ) {
        plan.reset();
        if budget.total_token_budget <= 0 {
            return;
        }
        let reserve_capacity = positive_i32_to_usize(budget.total_token_budget).unwrap_or(1);
        plan.contributions.reserve(reserve_capacity);

        // Pressure reflects the requested decode slots, not how many fit in the budget.
        let has_decode_pressure = !decode_slots.is_empty();
        Self::plan_decode_contributions(plan, slots, decode_slots, budget);
        Self::plan_prefill_contributions(
            plan,
            slots,
            prefill_slots,
            budget,
            prefill_chunk_size,
            has_decode_pressure,
        );
        Self::count_occupied_slots(plan);
    }

    /// Appends one decode contribution per eligible slot in `decode_slots`, stopping once the
    /// effective decode budget is exhausted.
    fn plan_decode_contributions(
        plan: &mut SharedBatchPlan,
        slots: &[SlotState],
        decode_slots: &[usize],
        budget: SchedulerTickBudget,
    ) {
        let mut remaining_decode_budget = budget.effective_decode_budget();
        for &slot_index in decode_slots {
            if remaining_decode_budget <= 0 {
                break;
            }
            let Some(slot) = slots.get(slot_index) else {
                continue;
            };
            let Some(request) = slot.request() else {
                continue;
            };
            let Some(&token) = slot.generated_tokens.last() else {
                continue;
            };
            plan.contributions.push(decode_contribution(
                slot_index,
                request.id,
                token,
                slot.mirror.n_past,
            ));
            plan.decode_token_count = plan.decode_token_count.saturating_add(1);
            remaining_decode_budget -= 1;
        }
    }

    /// Schedules remaining prompt tokens of `prefill_slots` round-robin within the effective
    /// prefill budget, marking the final token of each prompt completed this tick for logits.
    fn plan_prefill_contributions(
        plan: &mut SharedBatchPlan,
        slots: &[SlotState],
        prefill_slots: &[usize],
        budget: SchedulerTickBudget,
        prefill_chunk_size: i32,
        has_decode_pressure: bool,
    ) {
        let mut remaining_prefill_budget = budget.effective_prefill_budget();

        // Seed the round-robin working set with slots that still have prompt tokens left.
        plan.active_prefill_slots.reserve(prefill_slots.len());
        plan.in_tick_offset.reserve(prefill_slots.len());
        for &slot_index in prefill_slots {
            let Some(slot) = slots.get(slot_index) else {
                continue;
            };
            let Some(request) = slot.request() else {
                continue;
            };
            if slot.prefill_cursor >= request.prompt_tokens.len() {
                continue;
            }
            plan.active_prefill_slots.push(slot_index);
            plan.in_tick_offset.push(0_usize);
        }

        let mut next_prefill_slot_index = 0;
        // Consecutive visits that scheduled nothing. Nothing changes between such visits, so
        // once a full round of them has happened, the loop would otherwise spin forever.
        // (This assumes `resolve_prefill_slice_cap` is a pure function of its arguments.)
        let mut stalled_visits = 0_usize;
        while remaining_prefill_budget > 0 && !plan.active_prefill_slots.is_empty() {
            if next_prefill_slot_index >= plan.active_prefill_slots.len() {
                next_prefill_slot_index = 0;
            }
            let slot_index = plan.active_prefill_slots[next_prefill_slot_index];
            let Some(slot) = slots.get(slot_index) else {
                plan.erase_active_prefill_slot(next_prefill_slot_index);
                stalled_visits = 0;
                continue;
            };
            let Some(request) = slot.request() else {
                plan.erase_active_prefill_slot(next_prefill_slot_index);
                stalled_visits = 0;
                continue;
            };
            let prompt_end = request.prompt_tokens.len();
            if slot.prefill_cursor >= prompt_end {
                plan.erase_active_prefill_slot(next_prefill_slot_index);
                stalled_visits = 0;
                continue;
            }

            let slot_contribution_start = plan.contributions.len();
            let slot_chunk_budget = resolve_prefill_slice_cap(
                budget,
                prefill_chunk_size,
                remaining_prefill_budget,
                plan.active_prefill_slots.len(),
                has_decode_pressure,
            );
            // Resume where this slot's previous slice in the same tick stopped.
            let resume_offset = plan.in_tick_offset[next_prefill_slot_index];
            let mut remaining_slot_budget = slot_chunk_budget;
            for token_index in (slot.prefill_cursor + resume_offset)..prompt_end {
                if remaining_slot_budget <= 0 || remaining_prefill_budget <= 0 {
                    break;
                }
                let Ok(position) = i32::try_from(token_index) else {
                    break;
                };
                plan.contributions.push(prefill_contribution(
                    slot_index,
                    request.id,
                    request.prompt_tokens[token_index],
                    position,
                ));
                plan.prefill_token_count = plan.prefill_token_count.saturating_add(1);
                remaining_slot_budget -= 1;
                remaining_prefill_budget -= 1;
            }

            let added_this_iteration = plan.contributions.len() - slot_contribution_start;
            let total_added = resume_offset + added_this_iteration;
            plan.in_tick_offset[next_prefill_slot_index] = total_added;
            let slot_completed_prompt = slot.prefill_cursor + total_added >= prompt_end;
            if added_this_iteration > 0 && slot_completed_prompt {
                // Only the final prompt token needs logits to start decoding.
                if let Some(last) = plan.contributions.last_mut() {
                    last.request_logits = true;
                }
            }
            if slot_completed_prompt {
                plan.erase_active_prefill_slot(next_prefill_slot_index);
                stalled_visits = 0;
                continue;
            }

            if added_this_iteration == 0 {
                stalled_visits += 1;
                if stalled_visits >= plan.active_prefill_slots.len() {
                    break;
                }
            } else {
                stalled_visits = 0;
            }
            next_prefill_slot_index += 1;
        }
    }

    /// Sets `plan.occupied_slot_count` to the number of distinct slots in `plan.contributions`.
    ///
    /// Small slot indices are counted through a `u64` bitmask. Larger ones are deduplicated in
    /// `plan.occupied_overflow_slots`.
    fn count_occupied_slots(plan: &mut SharedBatchPlan) {
        let mut occupied_mask: u64 = 0;
        for contribution in &plan.contributions {
            let slot_index = contribution.slot_index;
            if slot_index < FAST_OCCUPIED_SLOT_BITS {
                occupied_mask |= 1u64 << slot_index;
            } else if plan.occupied_overflow_slots.last() != Some(&slot_index)
                && !plan.occupied_overflow_slots.contains(&slot_index)
            {
                // The `last()` check skips the linear scan for the common case of
                // contiguous tokens from the same slot.
                plan.occupied_overflow_slots.push(slot_index);
            }
        }
        plan.occupied_slot_count = saturating_u32_to_i32(occupied_mask.count_ones())
            .saturating_add(saturating_usize_to_i32(plan.occupied_overflow_slots.len()));
    }

    /// Test-only hook that forwards to `apply_results::apply_decode_results`.
    #[cfg(test)]
    pub fn apply_decode_results(&self, slots: &mut [SlotState], plan: &SharedBatchPlan) {
        apply_results::apply_decode_results(slots, plan);
    }
}
fn decode_contribution(
slot_index: usize,
request_id: GenerateRequestId,
token: llama_token,
position: i32,
) -> BatchContribution {
BatchContribution {
slot_index,
request_id,
kind: BatchContributionKind::Decode,
token,
position,
request_logits: true,
}
}
fn prefill_contribution(
slot_index: usize,
request_id: GenerateRequestId,
token: llama_token,
position: i32,
) -> BatchContribution {
BatchContribution {
slot_index,
request_id,
kind: BatchContributionKind::Prefill,
token,
position,
request_logits: false,
}
}
#[cfg(test)]
#[path = "../../../tests/runtime/scheduler/batch_planner_tests.rs"]
mod batch_planner_tests;