use std::collections::HashMap;
/// Registration data for an engine that shares the global KV-token budget.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EngineHandle {
    /// Identifier used as the key for every per-engine scheduler call.
    pub engine_id: u32,
    /// Relative weight of this engine's share of the global budget (0 is treated as 1).
    pub share_weight: u32,
}
/// Token budget assigned to one engine by a single allocation pass.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EngineTokenBudget {
    /// Engine this budget belongs to.
    pub engine_id: u32,
    /// Maximum number of KV tokens the engine may use under this allocation.
    pub max_tokens: usize,
}
/// A request from one engine to reclaim KV blocks from another, lower-priority engine.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreemptionRequest {
    /// Engine asking for blocks; it is never chosen as its own donor.
    pub requesting_engine_id: u32,
    /// A donor must be able to free at least this many blocks.
    pub blocks_needed: usize,
    /// A donor's priority must be strictly lower than this value.
    pub request_priority: u8,
}
/// Outcome record for a preemption: which engine donated and how many blocks it freed.
///
/// NOTE(review): nothing in this file constructs this type; it appears to be reported by
/// callers after acting on a donor chosen by the scheduler — confirm against call sites.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreemptionResult {
    /// Engine that gave up blocks.
    pub donor_engine_id: u32,
    /// Number of blocks that were freed.
    pub blocks_freed: usize,
}
/// Per-engine bookkeeping kept by `GlobalKvScheduler`, keyed by engine id.
#[derive(Debug)]
struct EngineState {
/// Relative weight used when splitting the global budget; `register` stores at least 1.
share_weight: u32,
/// Most recent value passed to `report_usage`.
/// NOTE(review): not read by any allocation or preemption code in this file.
last_used_tokens: usize,
/// Activity flag set via `set_active`; `register` initialises it to `false`.
/// Inactive engines get a zero budget unless no engine is marked active.
was_active: bool,
}
/// Divides one global KV-token budget among registered engines by weight and activity,
/// and selects donor engines for preemption requests.
#[derive(Debug)]
pub struct GlobalKvScheduler {
/// Total number of tokens shared by all engines.
global_max_tokens: usize,
/// Per-engine state keyed by engine id.
engines: HashMap<u32, EngineState>,
/// Engine ids in registration order; iteration over engines follows this order so results
/// are deterministic (a `HashMap` alone would not be).
engine_order: Vec<u32>,
/// Upper bound, in thousandths of `global_max_tokens`, applied to each active engine's
/// weighted share. `new` sets 900; the builder setter clamps it to 100..=1000.
max_single_engine_permille: u32,
}
impl GlobalKvScheduler {
    /// Creates an empty scheduler that divides `global_max_tokens` among registered engines.
    ///
    /// The per-engine cap starts at 900 permille (90%) of the global budget; see
    /// [`GlobalKvScheduler::with_max_single_engine_permille`].
    pub fn new(global_max_tokens: usize) -> Self {
        Self {
            global_max_tokens,
            engines: HashMap::new(),
            engine_order: Vec::new(),
            max_single_engine_permille: 900,
        }
    }

    /// Builder-style setter for the per-engine cap, expressed in thousandths of the global
    /// budget and clamped to `100..=1000`.
    pub fn with_max_single_engine_permille(mut self, permille: u32) -> Self {
        self.max_single_engine_permille = permille.clamp(100, 1000);
        self
    }

    /// Registers an engine, or re-registers it if the id is already known.
    ///
    /// A weight of 0 is treated as 1. Re-registering replaces the weight and resets the
    /// recorded usage and active flag, but keeps the engine's original position in the
    /// registration order (which is the order of [`Self::allocate_budgets`] output).
    pub fn register(&mut self, handle: EngineHandle) {
        let state = EngineState {
            share_weight: handle.share_weight.max(1),
            last_used_tokens: 0,
            was_active: false,
        };
        // `insert` returns the previous state, so one lookup tells us whether the id is new.
        if self.engines.insert(handle.engine_id, state).is_none() {
            self.engine_order.push(handle.engine_id);
        }
    }

    /// Removes an engine; unknown ids are ignored.
    pub fn deregister(&mut self, engine_id: u32) {
        self.engines.remove(&engine_id);
        self.engine_order.retain(|&id| id != engine_id);
    }

    /// Marks an engine active or idle; unknown ids are ignored.
    pub fn set_active(&mut self, engine_id: u32, active: bool) {
        if let Some(state) = self.engines.get_mut(&engine_id) {
            state.was_active = active;
        }
    }

    /// Records an engine's most recent token usage; unknown ids are ignored.
    ///
    /// The value is stored but not consulted by [`Self::allocate_budgets`].
    pub fn report_usage(&mut self, engine_id: u32, used_tokens: usize) {
        if let Some(state) = self.engines.get_mut(&engine_id) {
            state.last_used_tokens = used_tokens;
        }
    }

    /// Splits `global_max_tokens` among registered engines in proportion to their weights.
    ///
    /// Budgets are returned in registration order:
    ///
    /// 1. Each engine's weighted share is `global_max_tokens * weight / total_weight`,
    ///    where `total_weight` covers *all* registered engines.
    /// 2. Active engines keep that share, capped at `max_single_engine_permille` of the
    ///    global budget. Idle engines get `0`, and their shares form an idle pool. If no
    ///    engine is marked active, every engine is treated as active.
    /// 3. The idle pool is lent to the active engines in proportion to their weights. This
    ///    borrowed capacity is not subject to the cap. Any integer-division remainder goes
    ///    to the first active engine.
    /// 4. Tokens lost to rounding in step 1 go to the first active engine, but only up to
    ///    its remaining headroom under the cap.
    ///
    /// Returns an empty vector when no engines are registered or the global budget is 0.
    pub fn allocate_budgets(&self) -> Vec<EngineTokenBudget> {
        if self.engines.is_empty() || self.global_max_tokens == 0 {
            return Vec::new();
        }
        let global = self.global_max_tokens;

        // (engine_id, weight, active) in registration order, for deterministic output.
        let mut entries: Vec<(u32, u64, bool)> = self
            .engine_order
            .iter()
            .filter_map(|&id| {
                self.engines
                    .get(&id)
                    .map(|s| (id, u64::from(s.share_weight), s.was_active))
            })
            .collect();
        if entries.iter().all(|&(_, _, active)| !active) {
            // Nobody is marked active: treat everyone as active rather than strand the pool.
            entries.iter_mut().for_each(|entry| entry.2 = true);
        }

        let total_weight: u64 = entries.iter().map(|&(_, w, _)| w).sum::<u64>().max(1);
        let active_weight: u64 = entries
            .iter()
            .filter(|&&(_, _, active)| active)
            .map(|&(_, w, _)| w)
            .sum();
        let cap_tokens = Self::scale(global, u64::from(self.max_single_engine_permille), 1000);
        // Anchor for remainders. Chosen by activity, not by `max_tokens > 0`, so an active
        // engine whose share rounds to zero can still receive leftover tokens.
        let first_active = entries.iter().position(|&(_, _, active)| active);

        // Steps 1-2: weighted shares, capped for active engines, pooled for idle ones.
        let mut budgets = Vec::with_capacity(entries.len());
        let mut natural_sum: usize = 0;
        let mut idle_pool: usize = 0;
        for &(engine_id, weight, active) in &entries {
            let natural = Self::scale(global, weight, total_weight);
            natural_sum += natural;
            let max_tokens = if active {
                natural.min(cap_tokens)
            } else {
                idle_pool += natural;
                0
            };
            budgets.push(EngineTokenBudget {
                engine_id,
                max_tokens,
            });
        }

        let Some(first) = first_active else {
            return budgets;
        };

        // Step 3: lend idle capacity by weight instead of handing it all to one engine.
        // `idle_pool > 0` implies some engine was genuinely active, so `active_weight >= 1`.
        if idle_pool > 0 {
            let mut lent = 0;
            for (budget, &(_, weight, active)) in budgets.iter_mut().zip(&entries) {
                if active {
                    let extra = Self::scale(idle_pool, weight, active_weight);
                    budget.max_tokens += extra;
                    lent += extra;
                }
            }
            // Floor division never lends more than the pool, so this cannot underflow.
            budgets[first].max_tokens += idle_pool - lent;
        }

        // Step 4: hand out tokens lost to integer division, respecting the cap.
        let rounding = global.saturating_sub(natural_sum);
        let budget = &mut budgets[first];
        let headroom = cap_tokens.saturating_sub(budget.max_tokens);
        budget.max_tokens += rounding.min(headroom);

        budgets
    }

    /// Returns `amount * num / den` rounded down. The product is computed in 128-bit
    /// arithmetic, so large budgets or weights cannot overflow. Returns 0 when `den` is 0
    /// and saturates at `usize::MAX`.
    fn scale(amount: usize, num: u64, den: u64) -> usize {
        if den == 0 {
            return 0;
        }
        let scaled = amount as u128 * u128::from(num) / u128::from(den);
        scaled.min(usize::MAX as u128) as usize
    }

    /// Returns the budget for one engine, computed by a full [`Self::allocate_budgets`]
    /// pass. Returns `None` if the engine is not registered or no budgets are allocated
    /// (e.g. the global budget is 0).
    pub fn budget_for(&self, engine_id: u32) -> Option<usize> {
        self.allocate_budgets()
            .into_iter()
            .find(|b| b.engine_id == engine_id)
            .map(|b| b.max_tokens)
    }

    /// Number of registered engines.
    pub fn engine_count(&self) -> usize {
        self.engines.len()
    }

    /// Picks the lowest-priority engine that can satisfy `request`.
    ///
    /// A candidate must:
    /// - be registered and not be the requester;
    /// - have a priority in `engine_priorities` strictly lower than `request.request_priority`;
    /// - have at least `request.blocks_needed` blocks in `engine_freeable_blocks`.
    ///
    /// Engines missing from either map are never chosen. Ties on priority go to the
    /// earliest-registered engine. Returns `None` when no engine qualifies.
    pub fn find_preemption_donor(
        &self,
        request: &PreemptionRequest,
        engine_priorities: &HashMap<u32, u8>,
        engine_freeable_blocks: &HashMap<u32, usize>,
    ) -> Option<u32> {
        self.engine_order
            .iter()
            .copied()
            .filter(|&id| id != request.requesting_engine_id && self.engines.contains_key(&id))
            .filter_map(|id| engine_priorities.get(&id).map(|&priority| (id, priority)))
            .filter(|&(_, priority)| priority < request.request_priority)
            .filter(|&(id, _)| {
                matches!(
                    engine_freeable_blocks.get(&id),
                    Some(&blocks) if blocks >= request.blocks_needed
                )
            })
            // `min_by_key` returns the first minimum, preserving registration order on ties.
            .min_by_key(|&(_, priority)| priority)
            .map(|(id, _)| id)
    }

    /// Total token budget shared by all engines.
    pub fn global_max_tokens(&self) -> usize {
        self.global_max_tokens
    }

    /// Replaces the global token budget; takes effect on the next allocation.
    pub fn set_global_max_tokens(&mut self, tokens: usize) {
        self.global_max_tokens = tokens;
    }
}
#[cfg(test)]
mod global_scheduler_tests {
    use super::*;

    /// Builds a scheduler from `(engine_id, share_weight, active)` triples, registered in order.
    fn make_scheduler(total: usize, engines: &[(u32, u32, bool)]) -> GlobalKvScheduler {
        let mut sched = GlobalKvScheduler::new(total);
        for &(id, weight, active) in engines {
            sched.register(EngineHandle {
                engine_id: id,
                share_weight: weight,
            });
            sched.set_active(id, active);
        }
        sched
    }

    /// Looks up one engine's allocated budget, failing the test if the engine has none.
    fn budget_of(sched: &GlobalKvScheduler, engine_id: u32) -> usize {
        sched
            .budget_for(engine_id)
            .expect("engine should have a budget")
    }

    /// Sum of all budgets handed out by one allocation pass.
    fn total_budget(sched: &GlobalKvScheduler) -> usize {
        sched.allocate_budgets().iter().map(|b| b.max_tokens).sum()
    }

    #[test]
    fn equal_weights_split_evenly() {
        let sched = make_scheduler(1000, &[(0, 1, true), (1, 1, true)]);
        let budgets = sched.allocate_budgets();
        assert_eq!(budgets.len(), 2);
        assert_eq!(total_budget(&sched), 1000);
        for b in &budgets {
            assert!((490..=510).contains(&b.max_tokens), "{b:?}");
        }
    }

    #[test]
    fn weighted_distribution() {
        let sched = make_scheduler(1000, &[(0, 1, true), (1, 3, true)]);
        let b0 = budget_of(&sched, 0);
        let b1 = budget_of(&sched, 1);
        assert!((249..=251).contains(&b0), "b0={b0}");
        assert!((749..=751).contains(&b1), "b1={b1}");
        assert_eq!(b0 + b1, 1000);
    }

    #[test]
    fn idle_engine_gets_zero() {
        let sched = make_scheduler(1000, &[(0, 1, true), (1, 1, false)]);
        assert_eq!(budget_of(&sched, 1), 0, "idle engine should get zero budget");
    }

    #[test]
    fn active_engine_absorbs_idle_share() {
        let sched = make_scheduler(1000, &[(0, 1, true), (1, 1, false)]);
        assert_eq!(budget_of(&sched, 0), 1000);
    }

    #[test]
    fn no_active_engine_means_all_active() {
        let sched = make_scheduler(1000, &[(0, 1, false), (1, 1, false)]);
        assert_eq!(budget_of(&sched, 0), 500);
        assert_eq!(budget_of(&sched, 1), 500);
    }

    #[test]
    fn zero_weight_is_treated_as_one() {
        let sched = make_scheduler(1000, &[(0, 0, true), (1, 1, true)]);
        assert_eq!(budget_of(&sched, 0), 500);
        assert_eq!(budget_of(&sched, 1), 500);
    }

    #[test]
    fn zero_global_budget_allocates_nothing() {
        let sched = make_scheduler(0, &[(0, 1, true)]);
        assert!(sched.allocate_budgets().is_empty());
        assert_eq!(sched.budget_for(0), None);
    }

    #[test]
    fn single_engine_cap_respected() {
        let mut sched = GlobalKvScheduler::new(1000).with_max_single_engine_permille(500);
        sched.register(EngineHandle {
            engine_id: 0,
            share_weight: 1,
        });
        sched.set_active(0, true);
        let budgets = sched.allocate_budgets();
        let b0 = budgets[0].max_tokens;
        assert!(b0 <= 500, "b0={b0} exceeds cap");
    }

    #[test]
    fn reregister_keeps_single_entry() {
        let mut sched = make_scheduler(1000, &[(0, 1, true), (1, 1, true)]);
        sched.register(EngineHandle {
            engine_id: 0,
            share_weight: 3,
        });
        assert_eq!(sched.engine_count(), 2);
        let ids: Vec<u32> = sched
            .allocate_budgets()
            .iter()
            .map(|b| b.engine_id)
            .collect();
        assert_eq!(ids, vec![0, 1]);
    }

    #[test]
    fn deregister_removes_engine() {
        let mut sched = make_scheduler(1000, &[(0, 1, true), (1, 1, true)]);
        sched.deregister(1);
        assert_eq!(sched.engine_count(), 1);
        let budgets = sched.allocate_budgets();
        assert_eq!(budgets.len(), 1);
        assert_eq!(budgets[0].engine_id, 0);
    }

    #[test]
    fn find_preemption_donor_picks_lowest_priority() {
        let sched = make_scheduler(1000, &[(0, 1, true), (1, 1, true), (2, 1, true)]);
        let req = PreemptionRequest {
            requesting_engine_id: 0,
            blocks_needed: 10,
            request_priority: 5,
        };
        let priorities: HashMap<u32, u8> = [(1, 1), (2, 3)].into();
        let freeable: HashMap<u32, usize> = [(1, 20), (2, 20)].into();
        let donor = sched.find_preemption_donor(&req, &priorities, &freeable);
        assert_eq!(donor, Some(1));
    }

    #[test]
    fn no_donor_when_all_higher_priority() {
        let sched = make_scheduler(1000, &[(0, 1, true), (1, 1, true)]);
        let req = PreemptionRequest {
            requesting_engine_id: 0,
            blocks_needed: 10,
            request_priority: 1,
        };
        let priorities: HashMap<u32, u8> = [(1, 10)].into();
        let freeable: HashMap<u32, usize> = [(1, 20)].into();
        let donor = sched.find_preemption_donor(&req, &priorities, &freeable);
        assert_eq!(donor, None);
    }

    #[test]
    fn donor_must_have_enough_freeable_blocks() {
        let sched = make_scheduler(1000, &[(0, 1, true), (1, 1, true), (2, 1, true)]);
        let req = PreemptionRequest {
            requesting_engine_id: 0,
            blocks_needed: 10,
            request_priority: 5,
        };
        let priorities: HashMap<u32, u8> = [(1, 1), (2, 3)].into();
        let freeable: HashMap<u32, usize> = [(1, 5), (2, 10)].into();
        let donor = sched.find_preemption_donor(&req, &priorities, &freeable);
        assert_eq!(donor, Some(2));
    }

    #[test]
    fn requester_is_never_its_own_donor() {
        let sched = make_scheduler(1000, &[(0, 1, true)]);
        let req = PreemptionRequest {
            requesting_engine_id: 0,
            blocks_needed: 1,
            request_priority: 9,
        };
        let priorities: HashMap<u32, u8> = [(0, 0)].into();
        let freeable: HashMap<u32, usize> = [(0, 100)].into();
        let donor = sched.find_preemption_donor(&req, &priorities, &freeable);
        assert_eq!(donor, None);
    }
}