use super::request::RequestId;
use crate::error::{Result, RuvLLMError};
use crate::kv_cache::{KvCacheConfig, TwoTierKvCache};
use parking_lot::RwLock;
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// Sizing parameters for the pooled KV cache: how many sequences can be
/// resident at once (slots), the paged-block geometry, and the model
/// dimensions that determine per-block memory.
#[derive(Debug, Clone)]
pub struct KvCachePoolConfig {
    /// Number of concurrently resident sequences (one slot per sequence).
    pub num_slots: usize,
    /// Maximum sequence length a single slot can hold, in tokens.
    pub max_seq_len: usize,
    /// Tokens stored per paged block.
    pub block_size: usize,
    /// Total paged blocks shared across all slots.
    pub total_blocks: usize,
    /// KV attention heads per layer (feeds `bytes_per_block`).
    pub num_kv_heads: usize,
    /// Dimension of each attention head.
    pub head_dim: usize,
    /// Transformer layer count (feeds `bytes_per_block`).
    pub num_layers: usize,
}
impl Default for KvCachePoolConfig {
fn default() -> Self {
Self {
num_slots: 256,
max_seq_len: 4096,
block_size: 16,
total_blocks: 4096,
num_kv_heads: 8,
head_dim: 128,
num_layers: 32,
}
}
}
impl KvCachePoolConfig {
    /// Number of blocks required to hold `seq_len` tokens (ceiling division).
    ///
    /// Written as quotient + remainder instead of `(n + bs - 1) / bs` so the
    /// computation cannot overflow for `seq_len` close to `usize::MAX`.
    pub fn blocks_for_seq_len(&self, seq_len: usize) -> usize {
        let full = seq_len / self.block_size;
        if seq_len % self.block_size == 0 {
            full
        } else {
            full + 1
        }
    }

    /// Bytes occupied by one block across all layers:
    /// 2 (K and V) x heads x head_dim x tokens-per-block x layers x 2 bytes
    /// per element.
    // NOTE(review): the trailing factor of 2 presumably means f16/bf16
    // storage — confirm against the actual cache element type.
    pub fn bytes_per_block(&self) -> usize {
        2 * self.num_kv_heads * self.head_dim * self.block_size * self.num_layers * 2
    }

    /// Total bytes the whole block pool occupies.
    pub fn total_memory(&self) -> usize {
        self.total_blocks * self.bytes_per_block()
    }
}
/// Per-request record of a slot reservation and its paged block table.
/// Kept in the manager's `allocations` map even while the request is
/// swapped out (`is_active == false`).
#[derive(Debug, Clone)]
pub struct KvCacheAllocation {
    /// Index of the owned cache slot.
    pub slot_id: usize,
    /// Tokens currently stored for this request.
    pub current_length: usize,
    /// Upper bound on tokens this allocation may grow to.
    pub max_length: usize,
    /// Block ids owned by this request, in sequence order.
    pub block_table: Vec<usize>,
    /// Number of entries in `block_table` (kept alongside for quick stats).
    pub num_blocks: usize,
    /// The request this allocation belongs to.
    pub request_id: RequestId,
    /// False while the request's KV data lives in swap space.
    pub is_active: bool,
}
impl KvCacheAllocation {
    /// Creates an empty, active allocation for `request_id` in `slot_id`
    /// with room for up to `max_length` tokens.
    pub fn new(slot_id: usize, request_id: RequestId, max_length: usize) -> Self {
        Self {
            slot_id,
            current_length: 0,
            max_length,
            block_table: Vec::new(),
            num_blocks: 0,
            request_id,
            is_active: true,
        }
    }

    /// Tokens of headroom left before `max_length` is reached
    /// (saturates to 0 if `current_length` somehow exceeds the maximum).
    pub fn remaining(&self) -> usize {
        self.max_length.saturating_sub(self.current_length)
    }

    /// Whether `additional_tokens` more tokens fit within `max_length`.
    ///
    /// Phrased via `remaining()` so a huge `additional_tokens` cannot
    /// overflow the addition the previous form performed.
    pub fn can_extend(&self, additional_tokens: usize) -> bool {
        additional_tokens <= self.remaining()
    }
}
/// Slot- and block-level bookkeeping for per-request KV caches, with a
/// swap space for preempted requests. Lock order used by the methods:
/// `allocations` -> `free_slots` -> `free_blocks`.
#[derive(Debug)]
pub struct KvCacheManager {
    /// Pool geometry used for all sizing decisions.
    config: KvCachePoolConfig,
    /// Per-request allocation records (retained while swapped out).
    allocations: RwLock<HashMap<RequestId, KvCacheAllocation>>,
    /// Slot ids currently unowned.
    free_slots: RwLock<VecDeque<usize>>,
    /// Block ids currently unowned.
    free_blocks: RwLock<VecDeque<usize>>,
    /// Live allocation count (incremented on allocate, decremented on free).
    active_allocations: AtomicUsize,
    /// Blocks currently handed out to allocations (for stats).
    allocated_blocks: AtomicUsize,
    /// One cache per slot; index == slot id.
    caches: Vec<Arc<TwoTierKvCache>>,
    /// Saved KV data for requests that have been swapped out.
    swap_space: RwLock<HashMap<RequestId, SwappedCache>>,
}
/// KV data captured from a slot when a request is swapped out, sufficient
/// to restore the request into a (possibly different) slot later.
#[derive(Debug, Clone)]
pub struct SwappedCache {
    /// Owning request.
    pub request_id: RequestId,
    /// Slot the data was captured from (informational; swap-in may pick another).
    pub original_slot: usize,
    /// Saved key tensor data.
    // NOTE(review): flattened f32 buffer as returned by
    // `TwoTierKvCache::get_all_kv` — exact layout defined by that type.
    pub keys: Vec<f32>,
    /// Saved value tensor data (same layout caveat as `keys`).
    pub values: Vec<f32>,
    /// Sequence length at the time of swap-out.
    pub seq_len: usize,
    /// Block table held at swap-out time (blocks themselves were freed).
    pub block_table: Vec<usize>,
}
impl KvCacheManager {
    /// Creates a manager with `config.num_slots` pre-built per-slot caches
    /// and every slot and block initially free.
    pub fn new(config: KvCachePoolConfig) -> Self {
        let free_slots: VecDeque<usize> = (0..config.num_slots).collect();
        let free_blocks: VecDeque<usize> = (0..config.total_blocks).collect();
        // Per-slot two-tier cache configuration derived from the pool config.
        // NOTE(review): tail_length is hard-coded to 256 — confirm this is the
        // intended tier split for all deployments.
        let kv_config = KvCacheConfig {
            tail_length: 256,
            max_tokens: config.max_seq_len,
            num_kv_heads: config.num_kv_heads,
            head_dim: config.head_dim,
            ..Default::default()
        };
        let caches: Vec<_> = (0..config.num_slots)
            .map(|_| Arc::new(TwoTierKvCache::new(kv_config.clone())))
            .collect();
        Self {
            config,
            allocations: RwLock::new(HashMap::new()),
            free_slots: RwLock::new(free_slots),
            free_blocks: RwLock::new(free_blocks),
            active_allocations: AtomicUsize::new(0),
            allocated_blocks: AtomicUsize::new(0),
            caches,
            swap_space: RwLock::new(HashMap::new()),
        }
    }

    /// Reserves one slot plus enough blocks for `max_tokens` tokens and
    /// returns the slot id.
    ///
    /// # Errors
    /// `OutOfMemory` if no slot is free or too few blocks remain; a slot
    /// reserved before the block check fails is rolled back.
    // NOTE(review): allocating twice for the same request_id overwrites the
    // old record and leaks its slot/blocks — callers are expected to free()
    // first. TODO: confirm, or reject duplicates here.
    pub fn allocate(&mut self, request_id: RequestId, max_tokens: usize) -> Result<usize> {
        let mut free_slots = self.free_slots.write();
        let slot_id = free_slots.pop_front().ok_or_else(|| {
            RuvLLMError::OutOfMemory("No free KV cache slots available".to_string())
        })?;
        let blocks_needed = self.config.blocks_for_seq_len(max_tokens);
        let mut free_blocks = self.free_blocks.write();
        if free_blocks.len() < blocks_needed {
            // Roll back the slot reservation before reporting failure.
            free_slots.push_front(slot_id);
            return Err(RuvLLMError::OutOfMemory(format!(
                "Not enough blocks: need {}, have {}",
                blocks_needed,
                free_blocks.len()
            )));
        }
        let mut block_table = Vec::with_capacity(blocks_needed);
        for _ in 0..blocks_needed {
            // Cannot fail: length verified above while the lock is held.
            block_table.push(
                free_blocks
                    .pop_front()
                    .expect("free_blocks length verified above"),
            );
        }
        let mut allocation = KvCacheAllocation::new(slot_id, request_id, max_tokens);
        allocation.block_table = block_table;
        allocation.num_blocks = blocks_needed;
        self.allocations.write().insert(request_id, allocation);
        self.active_allocations.fetch_add(1, Ordering::Relaxed);
        self.allocated_blocks
            .fetch_add(blocks_needed, Ordering::Relaxed);
        // The slot may hold data from a previous owner; start clean.
        self.caches[slot_id].clear();
        Ok(slot_id)
    }

    /// Records `new_tokens` additional tokens for `request_id`, acquiring
    /// more blocks if the new length crosses a block boundary.
    ///
    /// The block-growth branch is only reachable after a swap-in, which
    /// reserves blocks for the current length rather than `max_length`
    /// (a fresh `allocate` already reserves blocks for the full maximum).
    ///
    /// # Errors
    /// `NotFound` if the request has no allocation or is swapped out;
    /// `OutOfMemory` if the new length exceeds the reservation or no blocks
    /// remain.
    pub fn extend(&mut self, request_id: RequestId, new_tokens: usize) -> Result<()> {
        let mut allocations = self.allocations.write();
        let allocation = allocations.get_mut(&request_id).ok_or_else(|| {
            RuvLLMError::NotFound(format!("No allocation for request {}", request_id))
        })?;
        if !allocation.is_active {
            // Extending a swapped-out allocation would push fresh blocks into
            // a record whose slot and blocks were already released, corrupting
            // pool accounting. (NotFound reused for lack of a closer variant.)
            return Err(RuvLLMError::NotFound(format!(
                "Allocation for request {} is swapped out",
                request_id
            )));
        }
        let new_length = allocation.current_length + new_tokens;
        if new_length > allocation.max_length {
            return Err(RuvLLMError::OutOfMemory(format!(
                "Cannot extend: {} + {} > {}",
                allocation.current_length, new_tokens, allocation.max_length
            )));
        }
        let current_blocks = allocation.num_blocks;
        let needed_blocks = self.config.blocks_for_seq_len(new_length);
        if needed_blocks > current_blocks {
            let additional_blocks = needed_blocks - current_blocks;
            let mut free_blocks = self.free_blocks.write();
            if free_blocks.len() < additional_blocks {
                return Err(RuvLLMError::OutOfMemory(format!(
                    "Not enough blocks to extend: need {}, have {}",
                    additional_blocks,
                    free_blocks.len()
                )));
            }
            for _ in 0..additional_blocks {
                // Cannot fail: length verified above under the same lock.
                if let Some(block) = free_blocks.pop_front() {
                    allocation.block_table.push(block);
                }
            }
            allocation.num_blocks = needed_blocks;
            self.allocated_blocks
                .fetch_add(additional_blocks, Ordering::Relaxed);
        }
        allocation.current_length = new_length;
        Ok(())
    }

    /// Releases everything the request holds: its allocation record, any
    /// swapped-out copy, and — only if the allocation is still active — its
    /// slot and blocks. A no-op for unknown request ids.
    pub fn free(&mut self, request_id: RequestId) {
        // Drop any swapped-out copy so its memory is reclaimed too
        // (previously leaked for requests freed while swapped out).
        self.swap_space.write().remove(&request_id);
        let mut allocations = self.allocations.write();
        if let Some(allocation) = allocations.remove(&request_id) {
            if allocation.is_active {
                // Only an active allocation still owns its slot and blocks;
                // swap_out() already returned them to the free pools, and
                // returning them again here would double-free both.
                self.free_slots.write().push_back(allocation.slot_id);
                let mut free_blocks = self.free_blocks.write();
                for block in allocation.block_table {
                    free_blocks.push_back(block);
                }
                self.allocated_blocks
                    .fetch_sub(allocation.num_blocks, Ordering::Relaxed);
                self.caches[allocation.slot_id].clear();
            }
            self.active_allocations.fetch_sub(1, Ordering::Relaxed);
        }
    }

    /// Number of slots currently free.
    pub fn available_slots(&self) -> usize {
        self.free_slots.read().len()
    }

    /// Number of blocks currently free.
    pub fn available_blocks(&self) -> usize {
        self.free_blocks.read().len()
    }

    /// Whether a new allocation of `max_tokens` tokens would succeed right now.
    /// (Advisory only: another allocator may win the race before `allocate`.)
    pub fn can_allocate(&self, max_tokens: usize) -> bool {
        let slots_available = !self.free_slots.read().is_empty();
        let blocks_needed = self.config.blocks_for_seq_len(max_tokens);
        let blocks_available = self.free_blocks.read().len() >= blocks_needed;
        slots_available && blocks_available
    }

    /// Snapshot of the request's allocation record, if any.
    pub fn get_allocation(&self, request_id: RequestId) -> Option<KvCacheAllocation> {
        self.allocations.read().get(&request_id).cloned()
    }

    /// Copy of the request's block table, if it has an allocation.
    pub fn get_block_table(&self, request_id: RequestId) -> Option<Vec<usize>> {
        self.allocations
            .read()
            .get(&request_id)
            .map(|a| a.block_table.clone())
    }

    /// Overwrites the recorded sequence length (e.g. after prefill).
    ///
    /// # Errors
    /// `NotFound` if the request has no allocation; `OutOfMemory` if `length`
    /// exceeds the reservation (which would desynchronize the block table
    /// from the logical sequence).
    pub fn set_length(&mut self, request_id: RequestId, length: usize) -> Result<()> {
        let mut allocations = self.allocations.write();
        let allocation = allocations.get_mut(&request_id).ok_or_else(|| {
            RuvLLMError::NotFound(format!("No allocation for request {}", request_id))
        })?;
        if length > allocation.max_length {
            return Err(RuvLLMError::OutOfMemory(format!(
                "Cannot set length {} beyond max {}",
                length, allocation.max_length
            )));
        }
        allocation.current_length = length;
        Ok(())
    }

    /// Saves the request's KV data to swap space and returns its slot and
    /// blocks to the free pools. The allocation record is kept, marked
    /// inactive, with its (now released) block table emptied.
    ///
    /// # Errors
    /// `NotFound` if the request has no allocation or is already swapped out
    /// (a second swap-out would double-free the slot and blocks).
    pub fn swap_out(&mut self, request_id: RequestId) -> Result<()> {
        let allocation = {
            let allocations = self.allocations.read();
            allocations.get(&request_id).cloned().ok_or_else(|| {
                RuvLLMError::NotFound(format!("No allocation for request {}", request_id))
            })?
        };
        if !allocation.is_active {
            return Err(RuvLLMError::NotFound(format!(
                "Allocation for request {} is already swapped out",
                request_id
            )));
        }
        let (keys, values) = self.caches[allocation.slot_id].get_all_kv();
        let swapped = SwappedCache {
            request_id,
            original_slot: allocation.slot_id,
            keys,
            values,
            seq_len: allocation.current_length,
            block_table: allocation.block_table.clone(),
        };
        self.swap_space.write().insert(request_id, swapped);
        self.caches[allocation.slot_id].clear();
        self.free_slots.write().push_back(allocation.slot_id);
        {
            let mut free_blocks = self.free_blocks.write();
            for block in &allocation.block_table {
                free_blocks.push_back(*block);
            }
        }
        // Keep the counter in sync with the blocks just released (previously
        // omitted, which made stats drift after every swap cycle).
        self.allocated_blocks
            .fetch_sub(allocation.num_blocks, Ordering::Relaxed);
        if let Some(alloc) = self.allocations.write().get_mut(&request_id) {
            alloc.is_active = false;
            // The blocks no longer belong to this allocation.
            alloc.block_table.clear();
            alloc.num_blocks = 0;
        }
        Ok(())
    }

    /// Restores a swapped-out request into a free slot, reserving blocks for
    /// its current sequence length, and returns the new slot id.
    ///
    /// On any failure the saved `SwappedCache` is put back into swap space
    /// and any reserved slot/blocks are returned, so the request's KV data
    /// is never lost (previously it was silently dropped on error).
    ///
    /// # Errors
    /// `NotFound` if nothing is swapped for the request; `OutOfMemory` if no
    /// slot or too few blocks are free; any error from the cache `append`.
    pub fn swap_in(&mut self, request_id: RequestId) -> Result<usize> {
        let swapped = self.swap_space.write().remove(&request_id).ok_or_else(|| {
            RuvLLMError::NotFound(format!("No swapped cache for request {}", request_id))
        })?;
        let slot_id = match self.free_slots.write().pop_front() {
            Some(id) => id,
            None => {
                // Preserve the saved data for a later retry.
                self.swap_space.write().insert(request_id, swapped);
                return Err(RuvLLMError::OutOfMemory(
                    "No free slots for swap in".to_string(),
                ));
            }
        };
        let blocks_needed = self.config.blocks_for_seq_len(swapped.seq_len);
        let block_table: Vec<usize> = {
            let mut free_blocks = self.free_blocks.write();
            if free_blocks.len() < blocks_needed {
                drop(free_blocks);
                // Roll back the slot and preserve the saved data.
                self.free_slots.write().push_front(slot_id);
                self.swap_space.write().insert(request_id, swapped);
                return Err(RuvLLMError::OutOfMemory(
                    "Not enough blocks for swap in".to_string(),
                ));
            }
            (0..blocks_needed)
                .map(|_| {
                    // Cannot fail: length verified above under the same lock.
                    free_blocks
                        .pop_front()
                        .expect("free_blocks length verified above")
                })
                .collect()
        };
        // Destination slot should already be clean (cleared on free/swap_out),
        // but clearing is cheap insurance against stale data.
        self.caches[slot_id].clear();
        if let Err(e) = self.caches[slot_id].append(&swapped.keys, &swapped.values) {
            // Roll back every reservation and keep the swapped copy
            // (the `?` form leaked the slot/blocks and dropped the data).
            {
                let mut free_blocks = self.free_blocks.write();
                for block in &block_table {
                    free_blocks.push_back(*block);
                }
            }
            self.free_slots.write().push_front(slot_id);
            self.swap_space.write().insert(request_id, swapped);
            return Err(e);
        }
        // Mirror of the swap_out decrement: these blocks are allocated again.
        self.allocated_blocks
            .fetch_add(blocks_needed, Ordering::Relaxed);
        if let Some(alloc) = self.allocations.write().get_mut(&request_id) {
            alloc.slot_id = slot_id;
            alloc.block_table = block_table;
            alloc.num_blocks = blocks_needed;
            alloc.is_active = true;
        }
        Ok(slot_id)
    }

    /// Whether the request currently lives in swap space.
    pub fn is_swapped(&self, request_id: RequestId) -> bool {
        self.swap_space.read().contains_key(&request_id)
    }

    /// Point-in-time utilization snapshot (individual fields are read under
    /// separate locks, so the snapshot is not globally atomic).
    pub fn stats(&self) -> KvCacheManagerStats {
        KvCacheManagerStats {
            total_slots: self.config.num_slots,
            free_slots: self.available_slots(),
            active_allocations: self.active_allocations.load(Ordering::Relaxed),
            total_blocks: self.config.total_blocks,
            free_blocks: self.available_blocks(),
            allocated_blocks: self.allocated_blocks.load(Ordering::Relaxed),
            swapped_requests: self.swap_space.read().len(),
            block_size: self.config.block_size,
            bytes_per_block: self.config.bytes_per_block(),
            total_memory: self.config.total_memory(),
        }
    }

    /// Shared handle to the cache backing `slot_id`, if the id is in range.
    pub fn get_cache(&self, slot_id: usize) -> Option<&Arc<TwoTierKvCache>> {
        self.caches.get(slot_id)
    }

    /// The pool configuration this manager was built with.
    pub fn config(&self) -> &KvCachePoolConfig {
        &self.config
    }
}
/// Point-in-time utilization snapshot produced by `KvCacheManager::stats`.
#[derive(Debug, Clone, Default)]
pub struct KvCacheManagerStats {
    /// Configured slot count.
    pub total_slots: usize,
    /// Slots free at snapshot time.
    pub free_slots: usize,
    /// Live allocations (including swapped-out ones).
    pub active_allocations: usize,
    /// Configured block count.
    pub total_blocks: usize,
    /// Blocks free at snapshot time.
    pub free_blocks: usize,
    /// Blocks currently handed out to allocations.
    pub allocated_blocks: usize,
    /// Requests currently resident in swap space.
    pub swapped_requests: usize,
    /// Tokens per block (from config).
    pub block_size: usize,
    /// Bytes per block (from config).
    pub bytes_per_block: usize,
    /// Total pool memory in bytes (from config).
    pub total_memory: usize,
}
impl KvCacheManagerStats {
    /// Fraction of slots in use, in [0, 1]; 0.0 for an empty pool.
    pub fn slot_utilization(&self) -> f64 {
        match self.total_slots {
            0 => 0.0,
            total => self.active_allocations as f64 / total as f64,
        }
    }

    /// Fraction of blocks in use, in [0, 1]; 0.0 for an empty pool.
    pub fn block_utilization(&self) -> f64 {
        match self.total_blocks {
            0 => 0.0,
            total => self.allocated_blocks as f64 / total as f64,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny pool for tests: 4 slots, 32 blocks of 16 tokens each.
    fn create_test_manager() -> KvCacheManager {
        KvCacheManager::new(KvCachePoolConfig {
            num_slots: 4,
            max_seq_len: 128,
            block_size: 16,
            total_blocks: 32,
            num_kv_heads: 2,
            head_dim: 64,
            num_layers: 4,
        })
    }

    #[test]
    fn test_allocation() {
        let mut mgr = create_test_manager();
        let req = RequestId::new();
        let slot = mgr.allocate(req, 64).unwrap();
        assert!(slot < 4);
        let record = mgr.get_allocation(req).unwrap();
        assert_eq!(record.slot_id, slot);
        assert_eq!(record.max_length, 64);
        assert_eq!(record.current_length, 0);
    }

    #[test]
    fn test_extend() {
        let mut mgr = create_test_manager();
        let req = RequestId::new();
        mgr.allocate(req, 64).unwrap();
        mgr.extend(req, 32).unwrap();
        assert_eq!(mgr.get_allocation(req).unwrap().current_length, 32);
    }

    #[test]
    fn test_free() {
        let mut mgr = create_test_manager();
        let req = RequestId::new();
        let before = mgr.available_slots();
        mgr.allocate(req, 64).unwrap();
        assert_eq!(mgr.available_slots(), before - 1);
        mgr.free(req);
        assert_eq!(mgr.available_slots(), before);
        assert!(mgr.get_allocation(req).is_none());
    }

    #[test]
    fn test_out_of_slots() {
        let mut mgr = create_test_manager();
        // Exhaust all 4 slots with deterministic request ids.
        for i in 0..4 {
            let id = RequestId::from_uuid(uuid::Uuid::from_u128(i as u128));
            mgr.allocate(id, 32).unwrap();
        }
        assert!(mgr.allocate(RequestId::new(), 32).is_err());
    }

    #[test]
    fn test_can_allocate() {
        let mut mgr = create_test_manager();
        assert!(mgr.can_allocate(64));
        for i in 0..4 {
            let id = RequestId::from_uuid(uuid::Uuid::from_u128(i as u128));
            mgr.allocate(id, 32).unwrap();
        }
        assert!(!mgr.can_allocate(64));
    }

    #[test]
    fn test_stats() {
        let mut mgr = create_test_manager();
        mgr.allocate(RequestId::new(), 64).unwrap();
        let snapshot = mgr.stats();
        assert_eq!(snapshot.total_slots, 4);
        assert_eq!(snapshot.free_slots, 3);
        assert_eq!(snapshot.active_allocations, 1);
    }
}