candle-cuda-vmm 0.1.1

CUDA Virtual Memory Management bindings for elastic KV cache allocation in Candle
Documentation
//! KV cache simulation example for LLM inference.
//!
//! This example simulates how candle-cuda-vmm can be used for elastic KV cache
//! allocation in an LLM inference engine like Lightbulb.

use candle_cuda_vmm::{SharedMemoryPool, Result};
use candle_core::Device;
use std::collections::HashMap;

/// Simulated KV cache bookkeeping for a single model.
///
/// This type only tracks counters; it does not own or allocate any memory
/// itself.
#[derive(Debug)]
struct ModelKVCache {
    /// Identifier of the model this cache was created for.
    model_id: String,
    /// Maximum number of tokens the cache is allowed to hold.
    max_tokens: usize,
    /// Size of one token's cache entry, in bytes.
    bytes_per_token: usize,
    /// Number of tokens currently accounted for; starts at zero.
    current_tokens: usize,
}

impl ModelKVCache {
    /// Creates an empty cache record (`current_tokens == 0`) for `model_id`.
    fn new(model_id: String, max_tokens: usize, bytes_per_token: usize) -> Self {
        Self {
            model_id,
            max_tokens,
            bytes_per_token,
            current_tokens: 0,
        }
    }

    /// Returns the number of bytes needed to store `tokens` tokens.
    ///
    /// The product saturates at `usize::MAX` instead of overflowing. A plain
    /// multiplication would panic in debug builds and, worse, wrap to a small
    /// value in release builds, yielding a byte count far too small for the
    /// requested tokens.
    fn tokens_to_bytes(&self, tokens: usize) -> usize {
        tokens.saturating_mul(self.bytes_per_token)
    }
}

/// Simulated multi-model inference engine.
///
/// Combines a [`SharedMemoryPool`] with per-model token bookkeeping so that
/// several models can be registered against the same pool.
struct InferenceEngine {
    /// Pool the models are registered with and allocate from.
    shared_pool: SharedMemoryPool,
    /// Token bookkeeping for each loaded model, keyed by model id.
    models: HashMap<String, ModelKVCache>,
}

impl InferenceEngine {
    /// Creates an engine with no models loaded.
    ///
    /// `total_gpu_memory` and `device` are forwarded unchanged to
    /// [`SharedMemoryPool::new`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned while constructing the pool.
    fn new(total_gpu_memory: usize, device: Device) -> Result<Self> {
        let shared_pool = SharedMemoryPool::new(total_gpu_memory, device)?;
        Ok(Self {
            shared_pool,
            models: HashMap::new(),
        })
    }

    /// Registers `model_id` with the shared pool and starts tracking its
    /// token usage.
    ///
    /// * `virtual_capacity` - capacity in bytes passed to the pool's
    ///   `register_model`.
    /// * `max_tokens` - token limit later enforced by [`Self::process_request`].
    /// * `bytes_per_token` - bytes requested from the pool per token.
    ///
    /// # Errors
    ///
    /// Propagates any error from `register_model`; in that case no
    /// bookkeeping entry is created.
    fn load_model(
        &mut self,
        model_id: &str,
        virtual_capacity: usize,
        max_tokens: usize,
        bytes_per_token: usize,
    ) -> Result<()> {
        println!("Loading model: {}", model_id);
        println!("  Virtual capacity: {} GB", virtual_capacity / (1024 * 1024 * 1024));
        println!("  Max tokens: {}", max_tokens);
        println!("  Bytes per token: {}", bytes_per_token);

        self.shared_pool.register_model(model_id, virtual_capacity)?;

        let cache = ModelKVCache::new(model_id.to_string(), max_tokens, bytes_per_token);
        self.models.insert(model_id.to_string(), cache);

        println!("  ✓ Model loaded successfully\n");
        Ok(())
    }

    /// Accounts for `num_tokens` additional tokens on `model_id` and requests
    /// the matching number of bytes from the shared pool.
    ///
    /// The token counter is only advanced after the pool allocation succeeds.
    ///
    /// # Errors
    ///
    /// * `VmmError::ModelNotFound` if `model_id` was never loaded.
    /// * A "Token limit exceeded" error if the request would push the model
    ///   past its `max_tokens`.
    /// * Any error returned by the pool's `allocate_for_model`.
    fn process_request(&mut self, model_id: &str, num_tokens: usize) -> Result<()> {
        let cache = self
            .models
            .get_mut(model_id)
            .ok_or_else(|| candle_cuda_vmm::VmmError::ModelNotFound(model_id.to_string()))?;

        // `checked_add` treats an overflowing sum as "over the limit". A plain
        // `+` would wrap in release builds and let an absurdly large request
        // slip past the limit check.
        let new_total = match cache.current_tokens.checked_add(num_tokens) {
            Some(total) if total <= cache.max_tokens => total,
            _ => return Err(candle_cuda_vmm::VmmError::other("Token limit exceeded")),
        };

        // Allocate KV cache for new tokens
        let size = cache.tokens_to_bytes(num_tokens);
        let _addr = self.shared_pool.allocate_for_model(model_id, size)?;

        cache.current_tokens = new_total;

        Ok(())
    }

    /// Prints pool-wide statistics followed by per-model statistics to stdout.
    ///
    /// Models are listed in ascending order of their id so the output is
    /// stable across runs (`HashMap` iteration order is not). Models for which
    /// the pool returns no statistics are skipped.
    fn print_stats(&self) {
        println!("=== Inference Engine Statistics ===");

        let global_stats = self.shared_pool.global_stats();
        println!("\nGlobal Memory:");
        println!("  Physical limit: {} GB", global_stats.physical_limit / (1024 * 1024 * 1024));
        // f64 represents byte counts exactly up to 2^53; f32 would already
        // round values above 2^24 before the division.
        println!(
            "  Physical usage: {} MB ({:.1}%)",
            global_stats.physical_usage / (1024 * 1024),
            (global_stats.physical_usage as f64 / global_stats.physical_limit as f64) * 100.0
        );
        println!("  Active models: {}", global_stats.num_models);

        println!("\nPer-Model Statistics:");
        let mut entries: Vec<(&String, &ModelKVCache)> = self.models.iter().collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        for (model_id, cache) in entries {
            if let Some(stats) = self.shared_pool.get_model_stats(model_id) {
                println!("  {}:", model_id);
                println!("    Virtual capacity: {} GB", stats.virtual_capacity / (1024 * 1024 * 1024));
                println!("    Physical usage: {} MB", stats.physical_usage / (1024 * 1024));
                println!("    Mapped pages: {}", stats.mapped_pages);
                println!("    Current tokens: {}/{}", cache.current_tokens, cache.max_tokens);
            }
        }
        println!();
    }
}

/// Runs the simulation: builds an engine, loads three models, replays a small
/// request workload, and prints statistics before and after the workload.
///
/// When no CUDA device can be opened, a message is written to stderr and the
/// function returns `Ok(())` without running the simulation.
fn main() -> Result<()> {
    const KIB: usize = 1024;
    const GIB: usize = 1024 * 1024 * 1024;

    println!("=== KV Cache Simulation for Multi-Model LLM Serving ===\n");

    // Bail out gracefully on machines without a usable CUDA device.
    let device = match Device::new_cuda(0) {
        Ok(cuda) => cuda,
        Err(err) => {
            eprintln!("Error: CUDA device not available: {}", err);
            eprintln!("This example requires a CUDA-capable GPU.");
            return Ok(());
        }
    };

    // The engine is capped at 16GB of physical memory.
    let physical_limit = 16 * 1024 * 1024 * 1024u64;
    println!(
        "Creating inference engine with {} GB physical memory limit\n",
        physical_limit / (1024 * 1024 * 1024)
    );

    let mut engine = InferenceEngine::new(physical_limit as usize, device)?;

    // Every model gets a large virtual capacity while sharing physical memory.
    // Columns: model id, virtual capacity, max tokens, bytes per token
    // (per-token sizes are simplified).
    let model_specs = [
        ("llama-7b", 64 * GIB, 8192, 512 * KIB),
        ("gpt2", 32 * GIB, 4096, 256 * KIB),
        ("mistral-7b", 64 * GIB, 8192, 512 * KIB),
    ];
    for (model_id, virtual_capacity, max_tokens, bytes_per_token) in model_specs {
        engine.load_model(model_id, virtual_capacity, max_tokens, bytes_per_token)?;
    }

    engine.print_stats();

    println!("=== Simulating Inference Workload ===\n");

    // Columns: banner printed before the request, target model, token count.
    let workload = [
        ("Request 1: llama-7b (100 tokens)", "llama-7b", 100),
        ("Request 2: gpt2 (50 tokens)", "gpt2", 50),
        ("Request 3: mistral-7b (150 tokens)", "mistral-7b", 150),
        ("Request 4: llama-7b (200 more tokens)", "llama-7b", 200),
    ];
    for (banner, model_id, num_tokens) in workload {
        println!("{}", banner);
        engine.process_request(model_id, num_tokens)?;
        println!("  ✓ Processed successfully\n");
    }

    engine.print_stats();

    let summary = [
        "=== Benefits of Elastic KV Cache ===",
        "✓ Only allocated memory for actual tokens (not max capacity)",
        "✓ Multiple models share physical memory pool",
        "✓ Each model has large virtual address space",
        "✓ Memory allocated on-demand as requests arrive",
        "✓ Reduced time-to-first-token vs static pre-allocation\n",
        "Simulation completed successfully!",
    ];
    for line in summary {
        println!("{}", line);
    }

    Ok(())
}