Skip to main content

oxide_rs/inference/
prefix_cache.rs

//! Prefix Caching for LLM Inference
//!
//! Hash-keyed KV cache for prompt prefixes that recur across requests, such as
//! repeated system prompts. Reusing a cached prefix dramatically reduces Time to
//! First Token (TTFT) for API workloads with repeated system prompts.

6use std::collections::hash_map::DefaultHasher;
7use std::collections::HashMap;
8use std::hash::{Hash, Hasher};
9use std::sync::Arc;
10
11use candle_core::Tensor;
12use sha2::{Digest, Sha256};
13
/// Settings for [`PrefixCache`].
#[derive(Clone, Debug)]
pub struct PrefixCacheConfig {
    /// Upper bound on the cache's estimated memory footprint, in MiB.
    pub memory_budget_mb: usize,
    /// When `false`, lookups return nothing and inserts are ignored.
    pub enabled: bool,
}

impl Default for PrefixCacheConfig {
    /// A 512 MiB budget with caching enabled.
    fn default() -> Self {
        Self {
            memory_budget_mb: 512,
            enabled: true,
        }
    }
}
36
/// Identifies a cached prefix by 64-bit hashes of its inputs.
///
/// Only the hashes are kept, not the original strings, so two different inputs
/// whose hashes collide would map to the same key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Hash of the user prompt.
    pub prompt_hash: u64,
    /// Hash of the system prompt; a missing system prompt hashes like `""`.
    pub system_hash: u64,
    /// Hash of the model-configuration string.
    pub model_config_hash: u64,
}

impl CacheKey {
    /// Builds a key by hashing each component independently.
    ///
    /// `None` and `Some("")` for `system_prompt` produce the same key.
    pub fn new(prompt: &str, system_prompt: Option<&str>, model_config: &str) -> Self {
        Self {
            prompt_hash: Self::hash_string(prompt),
            system_hash: Self::hash_string(system_prompt.unwrap_or_default()),
            model_config_hash: Self::hash_string(model_config),
        }
    }

    /// Hashes `s` with the standard library's `DefaultHasher`.
    ///
    /// Per the std docs, `DefaultHasher`'s algorithm is unspecified and may change
    /// between Rust releases, so these values are only meaningful in-process.
    fn hash_string(s: &str) -> u64 {
        let mut state = DefaultHasher::new();
        s.hash(&mut state);
        state.finish()
    }
}
63
64pub struct CachedPrefix {
65    pub key: CacheKey,
66    pub tokens: Vec<u32>,
67    pub kv_cache: Vec<CachedLayer>,
68    pub access_count: u64,
69    pub last_access: std::time::Instant,
70}
71
72pub struct CachedLayer {
73    pub k_cache: Tensor,
74    pub v_cache: Tensor,
75}
76
77pub struct PrefixCache {
78    config: PrefixCacheConfig,
79    cache: HashMap<CacheKey, Arc<CachedPrefix>>,
80    access_order: Vec<CacheKey>,
81    current_memory_bytes: usize,
82    memory_budget_bytes: usize,
83}
84
85impl PrefixCache {
86    pub fn new(config: PrefixCacheConfig) -> Self {
87        let memory_budget_bytes = config.memory_budget_mb * 1024 * 1024;
88        Self {
89            config,
90            cache: HashMap::new(),
91            access_order: Vec::new(),
92            current_memory_bytes: 0,
93            memory_budget_bytes,
94        }
95    }
96
97    pub fn config(&self) -> &PrefixCacheConfig {
98        &self.config
99    }
100
101    pub fn is_enabled(&self) -> bool {
102        self.config.enabled
103    }
104
105    pub fn get(&self, key: &CacheKey) -> Option<Arc<CachedPrefix>> {
106        if !self.config.enabled {
107            return None;
108        }
109
110        self.cache.get(key).cloned()
111    }
112
113    pub fn insert(&mut self, key: CacheKey, tokens: Vec<u32>, _kv_cache: Vec<CachedLayer>) {
114        if !self.config.enabled {
115            return;
116        }
117
118        let estimated_size = tokens.len() * 4 + 1024;
119
120        while self.current_memory_bytes + estimated_size > self.memory_budget_bytes
121            && !self.access_order.is_empty()
122        {
123            self.evict_lru();
124        }
125
126        if self.current_memory_bytes + estimated_size > self.memory_budget_bytes {
127            tracing::warn!("Prefix cache: prompt too large to cache");
128            return;
129        }
130
131        let prefix = Arc::new(CachedPrefix {
132            key: key.clone(),
133            tokens,
134            kv_cache: Vec::new(),
135            access_count: 1,
136            last_access: std::time::Instant::now(),
137        });
138
139        self.current_memory_bytes += estimated_size;
140        self.cache.insert(key.clone(), prefix);
141        self.access_order.push(key);
142    }
143
144    pub fn touch(&mut self, key: &CacheKey) {
145        // Just move to back of access order for LRU
146        if let Some(pos) = self.access_order.iter().position(|k| k == key) {
147            self.access_order.remove(pos);
148            self.access_order.push(key.clone());
149        }
150    }
151
152    fn evict_lru(&mut self) {
153        if let Some(oldest_key) = self.access_order.first().cloned() {
154            if let Some(prefix) = self.cache.remove(&oldest_key) {
155                let size = prefix.tokens.len() * 4 + 1024;
156                self.current_memory_bytes = self.current_memory_bytes.saturating_sub(size);
157            }
158            self.access_order.remove(0);
159        }
160    }
161
162    pub fn clear(&mut self) {
163        self.cache.clear();
164        self.access_order.clear();
165        self.current_memory_bytes = 0;
166    }
167
168    pub fn stats(&self) -> PrefixCacheStats {
169        PrefixCacheStats {
170            num_entries: self.cache.len(),
171            memory_used_mb: self.current_memory_bytes / (1024 * 1024),
172            memory_budget_mb: self.config.memory_budget_mb,
173            hit_rate: 0.0,
174        }
175    }
176}
177
/// Snapshot of [`PrefixCache`] usage, as returned by `PrefixCache::stats`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PrefixCacheStats {
    /// Number of cached entries.
    pub num_entries: usize,
    /// Estimated memory in use, in whole MiB (rounded down).
    pub memory_used_mb: usize,
    /// Configured memory budget, in MiB.
    pub memory_budget_mb: usize,
    /// Intended fraction of lookups that hit.
    /// NOTE: hits and misses are not tracked yet, so `PrefixCache::stats`
    /// always reports `0.0`.
    pub hit_rate: f64,
}
184
185pub fn hash_prompt(prompt: &str) -> u64 {
186    let mut hasher = Sha256::new();
187    hasher.update(prompt.as_bytes());
188    let result = hasher.finalize();
189    u64::from_le_bytes(result[0..8].try_into().unwrap())
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a key that differs from other keys only in its prompt.
    fn key(prompt: &str) -> CacheKey {
        CacheKey::new(prompt, Some("system"), "config")
    }

    /// Builds an enabled cache with the given budget in MiB.
    fn cache_with_budget_mb(memory_budget_mb: usize) -> PrefixCache {
        PrefixCache::new(PrefixCacheConfig {
            memory_budget_mb,
            enabled: true,
        })
    }

    #[test]
    fn test_cache_key_creation() {
        let key1 = CacheKey::new("Hello", Some("System"), "config");
        let key2 = CacheKey::new("Hello", Some("System"), "config");
        let key3 = CacheKey::new("World", Some("System"), "config");

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_missing_system_prompt_matches_empty_system_prompt() {
        assert_eq!(
            CacheKey::new("prompt", None, "config"),
            CacheKey::new("prompt", Some(""), "config")
        );
    }

    #[test]
    fn test_prefix_cache_insert() {
        let mut cache = PrefixCache::new(PrefixCacheConfig::default());
        cache.insert(key("test prompt"), vec![1, 2, 3, 4], Vec::new());

        assert_eq!(cache.stats().num_entries, 1);
        let hit = cache.get(&key("test prompt")).expect("entry was just inserted");
        assert_eq!(hit.tokens, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_disabled_cache_ignores_inserts() {
        let mut cache = PrefixCache::new(PrefixCacheConfig {
            enabled: false,
            ..PrefixCacheConfig::default()
        });
        cache.insert(key("p"), vec![1, 2, 3], Vec::new());

        assert!(!cache.is_enabled());
        assert_eq!(cache.stats().num_entries, 0);
        assert!(cache.get(&key("p")).is_none());
    }

    #[test]
    fn test_entry_larger_than_budget_is_rejected() {
        let mut cache = cache_with_budget_mb(0);
        cache.insert(key("p"), vec![1], Vec::new());

        assert_eq!(cache.stats().num_entries, 0);
        assert!(cache.get(&key("p")).is_none());
    }

    #[test]
    fn test_lru_eviction_respects_touch() {
        // Each entry is estimated at 100_000 * 4 + 1024 bytes (~391 KiB), so a
        // 1 MiB budget holds two entries but not three.
        let mut cache = cache_with_budget_mb(1);
        let tokens = vec![0u32; 100_000];

        cache.insert(key("a"), tokens.clone(), Vec::new());
        cache.insert(key("b"), tokens.clone(), Vec::new());
        cache.touch(&key("a")); // "b" is now the least recently used entry
        cache.insert(key("c"), tokens, Vec::new());

        assert_eq!(cache.stats().num_entries, 2);
        assert!(cache.get(&key("a")).is_some());
        assert!(cache.get(&key("b")).is_none());
        assert!(cache.get(&key("c")).is_some());
    }

    #[test]
    fn test_clear_resets_stats() {
        let mut cache = PrefixCache::new(PrefixCacheConfig::default());
        cache.insert(key("p"), vec![1, 2, 3], Vec::new());
        cache.clear();

        let stats = cache.stats();
        assert_eq!(stats.num_entries, 0);
        assert_eq!(stats.memory_used_mb, 0);
        assert!(cache.get(&key("p")).is_none());
    }
}