//! KV-cache simulation for multi-model LLM serving: several models share one
//! physical GPU memory pool while each keeps its own large virtual reservation.

use candle_cuda_vmm::{SharedMemoryPool, Result};
use candle_core::Device;
use std::collections::HashMap;
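
/// Per-model KV-cache bookkeeping: tracks how many tokens are resident against
/// the model's configured limits.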
struct ModelKVCache {
    model_id: String,
    max_tokens: usize,
    bytes_per_token: usize,
    current_tokens: usize,
}

impl ModelKVCache {
    fn new(model_id: String, max_tokens: usize, bytes_per_token: usize) -> Self {
        Self {
            model_id,
            max_tokens,
            bytes_per_token,
            current_tokens: 0,
        }
    }

    /// Number of KV-cache bytes occupied by `tokens` tokens.
    fn tokens_to_bytes(&self, tokens: usize) -> usize {
        tokens * self.bytes_per_token
    }
}
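
/// Serves several models from one shared GPU memory pool. Each model registers
/// a large virtual capacity up front; physical memory is allocated only as
/// requests arrive.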
struct InferenceEngine {
    shared_pool: SharedMemoryPool,
    models: HashMap<String, ModelKVCache>,
}

impl InferenceEngine {
    fn new(total_gpu_memory: usize, device: Device) -> Result<Self> {
        let shared_pool = SharedMemoryPool::new(total_gpu_memory, device)?;
        Ok(Self {
            shared_pool,
            models: HashMap::new(),
        })
    }

    /// Registers a model with the shared pool: a large virtual capacity is
    /// reserved up front, while physical memory is mapped later on demand.
    fn load_model(
        &mut self,
        model_id: &str,
        virtual_capacity: usize,
        max_tokens: usize,
        bytes_per_token: usize,
    ) -> Result<()> {
        println!("Loading model: {}", model_id);
        println!("  Virtual capacity: {} GB", virtual_capacity / (1024 * 1024 * 1024));
        println!("  Max tokens: {}", max_tokens);
        println!("  Bytes per token: {}", bytes_per_token);

        self.shared_pool.register_model(model_id, virtual_capacity)?;

        let cache = ModelKVCache::new(model_id.to_string(), max_tokens, bytes_per_token);
        self.models.insert(model_id.to_string(), cache);

        println!("  ✓ Model loaded successfully\n");
        Ok(())
    }
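
    /// Appends `num_tokens` to a model's KV cache, allocating the backing
    /// memory for those tokens from the shared pool on demand.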
    fn process_request(&mut self, model_id: &str, num_tokens: usize) -> Result<()> {
        let cache = self.models.get_mut(model_id)
            .ok_or_else(|| candle_cuda_vmm::VmmError::ModelNotFound(model_id.to_string()))?;

        if cache.current_tokens + num_tokens > cache.max_tokens {
            return Err(candle_cuda_vmm::VmmError::other("Token limit exceeded"));
        }

        // Only the bytes needed for these tokens are requested from the pool,
        // not the model's full capacity.
        let size = cache.tokens_to_bytes(num_tokens);
        let _addr = self.shared_pool.allocate_for_model(model_id, size)?;
        cache.current_tokens += num_tokens;
        Ok(())
    }

    fn print_stats(&self) {
        println!("=== Inference Engine Statistics ===");

        let global_stats = self.shared_pool.global_stats();
        println!("\nGlobal Memory:");
        println!("  Physical limit: {} GB", global_stats.physical_limit / (1024 * 1024 * 1024));
        println!(
            "  Physical usage: {} MB ({:.1}%)",
            global_stats.physical_usage / (1024 * 1024),
            (global_stats.physical_usage as f32 / global_stats.physical_limit as f32) * 100.0
        );
        println!("  Active models: {}", global_stats.num_models);

        println!("\nPer-Model Statistics:");
        for (model_id, cache) in &self.models {
            if let Some(stats) = self.shared_pool.get_model_stats(model_id) {
                println!("  {}:", model_id);
                println!(
                    "    Virtual capacity: {} GB",
                    stats.virtual_capacity / (1024 * 1024 * 1024)
                );
                println!("    Physical usage: {} MB", stats.physical_usage / (1024 * 1024));
                println!("    Mapped pages: {}", stats.mapped_pages);
                println!("    Current tokens: {}/{}", cache.current_tokens, cache.max_tokens);
            }
        }
        println!();
    }
}

fn main() -> Result<()> {
    println!("=== KV Cache Simulation for Multi-Model LLM Serving ===\n");

    let device = match Device::new_cuda(0) {
        Ok(d) => d,
        Err(e) => {
            eprintln!("Error: CUDA device not available: {}", e);
            eprintln!("This example requires a CUDA-capable GPU.");
            return Ok(());
        }
    };

    // Shared physical budget across all models: 16 GB.
    let physical_limit = 16 * 1024 * 1024 * 1024u64;
    println!(
        "Creating inference engine with {} GB physical memory limit\n",
        physical_limit / (1024 * 1024 * 1024)
    );
    let mut engine = InferenceEngine::new(physical_limit as usize, device)?;

    engine.load_model(
        "llama-7b",
        64 * 1024 * 1024 * 1024, // 64 GB virtual capacity
        8192,                    // max tokens
        512 * 1024,              // bytes per token
    )?;
    engine.load_model(
        "gpt2",
        32 * 1024 * 1024 * 1024, // 32 GB virtual capacity
        4096,                    // max tokens
        256 * 1024,              // bytes per token
    )?;
    engine.load_model(
        "mistral-7b",
        64 * 1024 * 1024 * 1024, // 64 GB virtual capacity
        8192,                    // max tokens
        512 * 1024,              // bytes per token
    )?;

    engine.print_stats();

    println!("=== Simulating Inference Workload ===\n");

    println!("Request 1: llama-7b (100 tokens)");
    engine.process_request("llama-7b", 100)?;
    println!("  ✓ Processed successfully\n");

    println!("Request 2: gpt2 (50 tokens)");
    engine.process_request("gpt2", 50)?;
    println!("  ✓ Processed successfully\n");

    println!("Request 3: mistral-7b (150 tokens)");
    engine.process_request("mistral-7b", 150)?;
    println!("  ✓ Processed successfully\n");

    println!("Request 4: llama-7b (200 more tokens)");
    engine.process_request("llama-7b", 200)?;
    println!("  ✓ Processed successfully\n");

    engine.print_stats();
println!("=== Benefits of Elastic KV Cache ===");
println!("✓ Only allocated memory for actual tokens (not max capacity)");
println!("✓ Multiple models share physical memory pool");
println!("✓ Each model has large virtual address space");
println!("✓ Memory allocated on-demand as requests arrive");
println!("✓ Reduced time-to-first-token vs static pre-allocation\n");
println!("Simulation completed successfully!");
Ok(())
}