/// One cached prompt prefix: the exact token sequence together with the
/// K/V buffers computed for it, plus bookkeeping for LRU eviction and
/// hit statistics. Stored in the cache keyed by an FNV-1a hash of `tokens`.
#[cfg(feature = "gpu")]
pub struct PrefixCacheEntry {
    /// Full token sequence this entry caches; compared on lookup to rule
    /// out hash collisions.
    pub tokens: Vec<u32>,
    /// Cached key buffers (outer Vec grouping is presumably per layer —
    /// TODO confirm with the producer of these buffers).
    pub k_cache: Vec<Vec<f32>>,
    /// Cached value buffers (same layout as `k_cache`).
    pub v_cache: Vec<Vec<f32>>,
    /// Updated on every hit and on insertion; the entry with the oldest
    /// value is evicted first.
    pub last_access: std::time::Instant,
    /// Number of successful lookups served by this entry.
    pub hit_count: u64,
}
#[cfg(feature = "gpu")]
impl PrefixCache {
    /// Creates a cache that holds at most `max_entries` prefixes.
    pub fn new(max_entries: usize) -> Self {
        Self {
            entries: std::sync::Mutex::new(std::collections::HashMap::with_capacity(max_entries)),
            max_entries,
            hits: std::sync::atomic::AtomicU64::new(0),
            misses: std::sync::atomic::AtomicU64::new(0),
            evictions: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Locks the entry map. Panics only on poisoning, i.e. another thread
    /// panicked while holding the lock — a broken invariant, not a
    /// recoverable error.
    fn lock_entries(
        &self,
    ) -> std::sync::MutexGuard<'_, std::collections::HashMap<u64, PrefixCacheEntry>> {
        self.entries.lock().expect("mutex poisoned")
    }

    /// Relaxed increment — counters are statistics only, no ordering needed.
    fn inc(counter: &std::sync::atomic::AtomicU64) {
        counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    /// Relaxed read of a statistics counter.
    fn load(counter: &std::sync::atomic::AtomicU64) -> u64 {
        counter.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// FNV-1a hash over the token sequence. A 64-bit hash can collide, so
    /// every consumer must verify the stored tokens before trusting a hit.
    fn hash_tokens(tokens: &[u32]) -> u64 {
        const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
        const FNV_PRIME: u64 = 0x0100_0000_01b3;
        let mut hash = FNV_OFFSET;
        for &token in tokens {
            hash ^= token as u64;
            hash = hash.wrapping_mul(FNV_PRIME);
        }
        hash
    }

    /// Returns clones of the cached K/V buffers for `tokens`, refreshing the
    /// entry's LRU timestamp and hit count. Counts a miss both when nothing
    /// is stored and when the hash matches but the tokens differ (collision).
    #[allow(clippy::type_complexity)]
    pub fn lookup(&self, tokens: &[u32]) -> Option<(Vec<Vec<f32>>, Vec<Vec<f32>>)> {
        let hash = Self::hash_tokens(tokens);
        let mut entries = self.lock_entries();
        if let Some(entry) = entries.get_mut(&hash) {
            // Full comparison guards against FNV collisions.
            if entry.tokens == tokens {
                Self::inc(&self.hits);
                entry.last_access = std::time::Instant::now();
                entry.hit_count += 1;
                return Some((entry.k_cache.clone(), entry.v_cache.clone()));
            }
        }
        Self::inc(&self.misses);
        None
    }

    /// Inserts (or replaces) the entry for `tokens`, evicting the
    /// least-recently-used entry only when the insert would actually grow
    /// the map beyond `max_entries`.
    pub fn insert(&self, tokens: Vec<u32>, k_cache: Vec<Vec<f32>>, v_cache: Vec<Vec<f32>>) {
        // A zero-capacity cache stores nothing; previously one entry was
        // always retained, violating the configured bound.
        if self.max_entries == 0 {
            return;
        }
        let hash = Self::hash_tokens(&tokens);
        let mut entries = self.lock_entries();
        // Only evict when the key is new: replacing an existing key must not
        // push out an unrelated entry (the old code evicted unconditionally
        // once full, shrinking the cache on every replacement).
        if !entries.contains_key(&hash) && entries.len() >= self.max_entries {
            if let Some((&oldest_hash, _)) = entries.iter().min_by_key(|(_, e)| e.last_access) {
                entries.remove(&oldest_hash);
                Self::inc(&self.evictions);
            }
        }
        entries.insert(
            hash,
            PrefixCacheEntry {
                tokens,
                k_cache,
                v_cache,
                last_access: std::time::Instant::now(),
                hit_count: 0,
            },
        );
    }

    /// True only when an entry exists whose tokens exactly match `tokens`.
    /// A bare hash match is not enough — this now mirrors `lookup`'s
    /// collision check instead of reporting false positives.
    pub fn contains(&self, tokens: &[u32]) -> bool {
        let hash = Self::hash_tokens(tokens);
        self.lock_entries()
            .get(&hash)
            .map_or(false, |e| e.tokens == tokens)
    }

    /// Point-in-time snapshot of the counters. `hit_rate` is 0.0 before the
    /// first lookup to avoid a 0/0 division.
    pub fn stats(&self) -> PrefixCacheStats {
        let hits = Self::load(&self.hits);
        let misses = Self::load(&self.misses);
        let total = hits + misses;
        PrefixCacheStats {
            hits,
            misses,
            evictions: Self::load(&self.evictions),
            entries: self.lock_entries().len(),
            hit_rate: if total > 0 {
                hits as f64 / total as f64
            } else {
                0.0
            },
        }
    }

    /// Drops every cached entry; counters are deliberately left untouched.
    pub fn clear(&self) {
        self.lock_entries().clear();
    }

    /// Approximate heap usage of the cached payloads. Counts only the token
    /// and f32 buffer contents; Vec headers, spare capacity, and HashMap
    /// overhead are not included.
    pub fn memory_usage_bytes(&self) -> usize {
        let f32_size = std::mem::size_of::<f32>();
        let u32_size = std::mem::size_of::<u32>();
        self.lock_entries()
            .values()
            .map(|e| {
                let k_bytes: usize = e.k_cache.iter().map(|v| v.len() * f32_size).sum();
                let v_bytes: usize = e.v_cache.iter().map(|v| v.len() * f32_size).sum();
                k_bytes + v_bytes + e.tokens.len() * u32_size
            })
            .sum()
    }
}
#[cfg(feature = "gpu")]
impl Default for PrefixCache {
    /// A cache with a default capacity of 16 prefixes.
    fn default() -> Self {
        Self::new(16)
    }
}
/// Point-in-time snapshot of `PrefixCache` counters, as returned by
/// `PrefixCache::stats`.
#[cfg(feature = "gpu")]
#[derive(Debug, Clone)]
pub struct PrefixCacheStats {
    /// Lookups that returned a cached prefix.
    pub hits: u64,
    /// Lookups that found nothing usable.
    pub misses: u64,
    /// Entries removed to make room for new insertions.
    pub evictions: u64,
    /// Number of entries stored at snapshot time.
    pub entries: usize,
    /// hits / (hits + misses); 0.0 when no lookups have happened yet.
    pub hit_rate: f64,
}
/// Lifecycle state of a request inside `MultiRequestScheduler`.
#[cfg(feature = "gpu")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultiRequestState {
    /// Queued; not yet admitted to the active set.
    Pending,
    /// Prompt prefill in progress. NOTE(review): never assigned by the code
    /// visible here — confirm it is set elsewhere.
    Prefilling,
    /// Actively generating tokens; eligible for decode batches.
    Decoding,
    /// Finished: explicitly completed or token budget exhausted.
    Completed,
    /// Evicted from the active set. NOTE(review): never assigned by the code
    /// visible here — confirm it is set elsewhere.
    Preempted,
}
/// A single generation request tracked by `MultiRequestScheduler`.
#[cfg(feature = "gpu")]
#[derive(Clone)]
pub struct MultiSchedulerRequest {
    /// Scheduler-assigned unique id.
    pub id: u64,
    /// Prompt tokens supplied at submission.
    pub tokens: Vec<u32>,
    /// Tokens generated so far.
    pub generated: Vec<u32>,
    /// Maximum number of tokens to generate for this request.
    pub max_tokens: usize,
    /// Current lifecycle state.
    pub state: MultiRequestState,
    /// Position in the KV cache; advanced once per recorded token.
    pub kv_position: usize,
    /// When the request was created (basis for TTFT).
    pub arrival_time: std::time::Instant,
    /// When the first token was recorded; `None` until then.
    pub first_token_time: Option<std::time::Instant>,
}
#[cfg(feature = "gpu")]
impl MultiSchedulerRequest {
    /// Builds a fresh request in the `Pending` state, stamping its arrival
    /// time and pre-sizing the output buffer for `max_tokens` tokens.
    pub fn new(id: u64, tokens: Vec<u32>, max_tokens: usize) -> Self {
        let generated = Vec::with_capacity(max_tokens);
        Self {
            id,
            tokens,
            generated,
            max_tokens,
            state: MultiRequestState::Pending,
            kv_position: 0,
            arrival_time: std::time::Instant::now(),
            first_token_time: None,
        }
    }

    /// A request is finished once it has been marked `Completed` or its
    /// output budget is exhausted.
    pub fn is_complete(&self) -> bool {
        matches!(self.state, MultiRequestState::Completed)
            || self.generated.len() >= self.max_tokens
    }

    /// Time-to-first-token in milliseconds, if a first token was produced.
    pub fn ttft_ms(&self) -> Option<f64> {
        let first = self.first_token_time?;
        Some(first.duration_since(self.arrival_time).as_secs_f64() * 1000.0)
    }
}
/// Ordering applied to the active set when a decode batch is assembled.
#[cfg(feature = "gpu")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingPolicy {
    /// First-come-first-served: active order is left untouched.
    Fcfs,
    /// Shortest-job-first: fewest remaining tokens scheduled first.
    Sjf,
    /// Rotates the front request to the back on every batch.
    RoundRobin,
}
/// Continuous-batching scheduler: requests flow pending → active →
/// completed, with lifetime counters for throughput reporting.
#[cfg(feature = "gpu")]
pub struct MultiRequestScheduler {
    /// Requests awaiting admission, FIFO.
    pending: std::sync::Mutex<std::collections::VecDeque<MultiSchedulerRequest>>,
    /// Requests currently decoding (bounded by `max_concurrent`).
    active: std::sync::Mutex<Vec<MultiSchedulerRequest>>,
    /// Finished requests retained after collection.
    completed: std::sync::Mutex<Vec<MultiSchedulerRequest>>,
    /// Upper bound on requests returned per decode batch.
    max_batch_size: usize,
    /// Upper bound on simultaneously active requests.
    max_concurrent: usize,
    /// How the active set is ordered when batching.
    policy: SchedulingPolicy,
    /// Monotonic id source for submitted requests.
    next_id: std::sync::atomic::AtomicU64,
    // Lifetime statistics counters; updated with Relaxed ordering, so values
    // read concurrently are approximate.
    pub requests_submitted: std::sync::atomic::AtomicU64,
    pub requests_completed: std::sync::atomic::AtomicU64,
    pub tokens_generated: std::sync::atomic::AtomicU64,
    pub batch_iterations: std::sync::atomic::AtomicU64,
}
#[cfg(feature = "gpu")]
impl MultiRequestScheduler {
    /// Fresh zeroed statistics counter.
    fn new_counter() -> std::sync::atomic::AtomicU64 {
        std::sync::atomic::AtomicU64::new(0)
    }

    /// Relaxed increment — counters are statistics only, no ordering needed.
    fn inc(counter: &std::sync::atomic::AtomicU64) {
        counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    /// Relaxed add of `n` to a statistics counter.
    fn inc_by(counter: &std::sync::atomic::AtomicU64, n: u64) {
        counter.fetch_add(n, std::sync::atomic::Ordering::Relaxed);
    }

    /// Relaxed read of a statistics counter.
    fn load(counter: &std::sync::atomic::AtomicU64) -> u64 {
        counter.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Locks the pending queue; panics only on poisoning (broken invariant).
    fn lock_pending(
        &self,
    ) -> std::sync::MutexGuard<'_, std::collections::VecDeque<MultiSchedulerRequest>> {
        self.pending.lock().expect("mutex poisoned")
    }

    /// Locks the active set; panics only on poisoning (broken invariant).
    fn lock_active(&self) -> std::sync::MutexGuard<'_, Vec<MultiSchedulerRequest>> {
        self.active.lock().expect("mutex poisoned")
    }

    /// Locks the completed list; panics only on poisoning (broken invariant).
    fn lock_completed(&self) -> std::sync::MutexGuard<'_, Vec<MultiSchedulerRequest>> {
        self.completed.lock().expect("mutex poisoned")
    }

    /// Creates a scheduler that batches up to `max_batch_size` requests per
    /// step and keeps at most `max_concurrent` requests active at once.
    pub fn new(max_batch_size: usize, max_concurrent: usize, policy: SchedulingPolicy) -> Self {
        Self {
            pending: std::sync::Mutex::new(std::collections::VecDeque::new()),
            active: std::sync::Mutex::new(Vec::with_capacity(max_concurrent)),
            completed: std::sync::Mutex::new(Vec::new()),
            max_batch_size,
            max_concurrent,
            policy,
            next_id: Self::new_counter(),
            requests_submitted: Self::new_counter(),
            requests_completed: Self::new_counter(),
            tokens_generated: Self::new_counter(),
            batch_iterations: Self::new_counter(),
        }
    }

    /// Enqueues a new request and returns its scheduler-assigned id.
    pub fn submit(&self, tokens: Vec<u32>, max_tokens: usize) -> u64 {
        let id = self
            .next_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let request = MultiSchedulerRequest::new(id, tokens, max_tokens);
        self.lock_pending().push_back(request);
        Self::inc(&self.requests_submitted);
        id
    }

    /// Admits pending requests until the active set reaches `max_concurrent`,
    /// orders the active set per the scheduling policy, and returns up to
    /// `max_batch_size` `(id, kv_position)` pairs for decoding requests.
    /// Lock order here is active then pending — keep any future code that
    /// takes both locks in the same order to avoid deadlock.
    pub fn get_decode_batch(&self) -> Vec<(u64, usize)> {
        let mut active = self.lock_active();
        let mut pending = self.lock_pending();
        while active.len() < self.max_concurrent && !pending.is_empty() {
            if let Some(mut req) = pending.pop_front() {
                req.state = MultiRequestState::Decoding;
                active.push(req);
            }
        }
        match self.policy {
            SchedulingPolicy::Fcfs => {
                // Arrival order is already the queue order.
            },
            SchedulingPolicy::Sjf => {
                // Fewest remaining tokens first. saturating_sub: a request
                // whose `generated` has reached (or, via late tokens, passed)
                // its budget must not panic on underflow in a release of the
                // invariant; the plain `-` here could abort the process.
                active.sort_by_key(|r| r.max_tokens.saturating_sub(r.generated.len()));
            },
            SchedulingPolicy::RoundRobin => {
                // Rotate the front request to the back each batch.
                if active.len() > 1 {
                    let first = active.remove(0);
                    active.push(first);
                }
            },
        }
        active
            .iter()
            .filter(|r| r.state == MultiRequestState::Decoding)
            .take(self.max_batch_size)
            .map(|r| (r.id, r.kv_position))
            .collect()
    }

    /// Records one generated token for `request_id`: stamps TTFT on the
    /// first token, advances the KV position, and marks the request
    /// completed once its budget is used. Tokens for unknown or
    /// already-completed requests are ignored.
    pub fn record_token(&self, request_id: u64, token: u32) {
        let mut active = self.lock_active();
        if let Some(req) = active.iter_mut().find(|r| r.id == request_id) {
            // A completed request may still sit in `active` until
            // `collect_completed` runs; accepting more tokens would grow
            // `generated` past `max_tokens` and drift `kv_position`.
            if req.state == MultiRequestState::Completed {
                return;
            }
            if req.first_token_time.is_none() {
                req.first_token_time = Some(std::time::Instant::now());
            }
            req.generated.push(token);
            req.kv_position += 1;
            Self::inc(&self.tokens_generated);
            if req.is_complete() {
                req.state = MultiRequestState::Completed;
            }
        }
    }

    /// Removes completed requests from the active set, archives clones of
    /// them in the completed list, and returns them to the caller.
    pub fn collect_completed(&self) -> Vec<MultiSchedulerRequest> {
        let mut active = self.lock_active();
        let mut completed = self.lock_completed();
        let (done, still_active): (Vec<_>, Vec<_>) = active
            .drain(..)
            .partition(|r| r.state == MultiRequestState::Completed);
        *active = still_active;
        Self::inc_by(&self.requests_completed, done.len() as u64);
        completed.extend(done.iter().cloned());
        done
    }

    /// Counts one batch iteration (drives `avg_batch_size` in `stats`).
    pub fn step(&self) {
        Self::inc(&self.batch_iterations);
    }

    /// Point-in-time snapshot of counters and queue depths. Each lock is
    /// taken and released within its own statement, so no two locks are
    /// held simultaneously here.
    pub fn stats(&self) -> MultiRequestStats {
        let submitted = Self::load(&self.requests_submitted);
        let completed = Self::load(&self.requests_completed);
        let tokens = Self::load(&self.tokens_generated);
        let iterations = Self::load(&self.batch_iterations);
        let pending = self.lock_pending().len();
        let active = self.lock_active().len();
        MultiRequestStats {
            requests_submitted: submitted,
            requests_completed: completed,
            tokens_generated: tokens,
            batch_iterations: iterations,
            pending_requests: pending,
            active_requests: active,
            avg_batch_size: if iterations > 0 {
                tokens as f64 / iterations as f64
            } else {
                0.0
            },
        }
    }
}