oxide_rs/inference/
prefix_cache.rs1use std::collections::hash_map::DefaultHasher;
7use std::collections::HashMap;
8use std::hash::{Hash, Hasher};
9use std::sync::Arc;
10
11use candle_core::Tensor;
12use sha2::{Digest, Sha256};
13
/// Tuning knobs for [`PrefixCache`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PrefixCacheConfig {
    /// Upper bound on the cache's estimated memory use, in MiB.
    pub memory_budget_mb: usize,
    /// When `false`, lookups always miss and inserts are ignored.
    pub enabled: bool,
}

impl Default for PrefixCacheConfig {
    /// A 512 MiB budget with caching enabled.
    fn default() -> Self {
        Self {
            memory_budget_mb: 512,
            enabled: true,
        }
    }
}
36
/// Identity of a cached prefix: independent 64-bit hashes of the prompt, the system
/// prompt, and the model configuration string.
///
/// A missing system prompt hashes exactly like an empty one. The hashes come from
/// `DefaultHasher`, whose algorithm the standard library does not promise to keep stable
/// across releases, so these keys should not be persisted.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Hash of the user prompt text.
    pub prompt_hash: u64,
    /// Hash of the system prompt (the empty string when none was given).
    pub system_hash: u64,
    /// Hash of the model configuration string.
    pub model_config_hash: u64,
}

impl CacheKey {
    /// Builds a key from the raw prompt, the optional system prompt, and the model
    /// configuration string.
    pub fn new(prompt: &str, system_prompt: Option<&str>, model_config: &str) -> Self {
        Self {
            prompt_hash: Self::hash_string(prompt),
            system_hash: Self::hash_string(system_prompt.unwrap_or_default()),
            model_config_hash: Self::hash_string(model_config),
        }
    }

    /// Hashes `text` with a freshly created `DefaultHasher`.
    fn hash_string(text: &str) -> u64 {
        let mut state = DefaultHasher::new();
        Hash::hash(text, &mut state);
        state.finish()
    }
}
63
64pub struct CachedPrefix {
65 pub key: CacheKey,
66 pub tokens: Vec<u32>,
67 pub kv_cache: Vec<CachedLayer>,
68 pub access_count: u64,
69 pub last_access: std::time::Instant,
70}
71
72pub struct CachedLayer {
73 pub k_cache: Tensor,
74 pub v_cache: Tensor,
75}
76
77pub struct PrefixCache {
78 config: PrefixCacheConfig,
79 cache: HashMap<CacheKey, Arc<CachedPrefix>>,
80 access_order: Vec<CacheKey>,
81 current_memory_bytes: usize,
82 memory_budget_bytes: usize,
83}
84
85impl PrefixCache {
86 pub fn new(config: PrefixCacheConfig) -> Self {
87 let memory_budget_bytes = config.memory_budget_mb * 1024 * 1024;
88 Self {
89 config,
90 cache: HashMap::new(),
91 access_order: Vec::new(),
92 current_memory_bytes: 0,
93 memory_budget_bytes,
94 }
95 }
96
97 pub fn config(&self) -> &PrefixCacheConfig {
98 &self.config
99 }
100
101 pub fn is_enabled(&self) -> bool {
102 self.config.enabled
103 }
104
105 pub fn get(&self, key: &CacheKey) -> Option<Arc<CachedPrefix>> {
106 if !self.config.enabled {
107 return None;
108 }
109
110 self.cache.get(key).cloned()
111 }
112
113 pub fn insert(&mut self, key: CacheKey, tokens: Vec<u32>, _kv_cache: Vec<CachedLayer>) {
114 if !self.config.enabled {
115 return;
116 }
117
118 let estimated_size = tokens.len() * 4 + 1024;
119
120 while self.current_memory_bytes + estimated_size > self.memory_budget_bytes
121 && !self.access_order.is_empty()
122 {
123 self.evict_lru();
124 }
125
126 if self.current_memory_bytes + estimated_size > self.memory_budget_bytes {
127 tracing::warn!("Prefix cache: prompt too large to cache");
128 return;
129 }
130
131 let prefix = Arc::new(CachedPrefix {
132 key: key.clone(),
133 tokens,
134 kv_cache: Vec::new(),
135 access_count: 1,
136 last_access: std::time::Instant::now(),
137 });
138
139 self.current_memory_bytes += estimated_size;
140 self.cache.insert(key.clone(), prefix);
141 self.access_order.push(key);
142 }
143
144 pub fn touch(&mut self, key: &CacheKey) {
145 if let Some(pos) = self.access_order.iter().position(|k| k == key) {
147 self.access_order.remove(pos);
148 self.access_order.push(key.clone());
149 }
150 }
151
152 fn evict_lru(&mut self) {
153 if let Some(oldest_key) = self.access_order.first().cloned() {
154 if let Some(prefix) = self.cache.remove(&oldest_key) {
155 let size = prefix.tokens.len() * 4 + 1024;
156 self.current_memory_bytes = self.current_memory_bytes.saturating_sub(size);
157 }
158 self.access_order.remove(0);
159 }
160 }
161
162 pub fn clear(&mut self) {
163 self.cache.clear();
164 self.access_order.clear();
165 self.current_memory_bytes = 0;
166 }
167
168 pub fn stats(&self) -> PrefixCacheStats {
169 PrefixCacheStats {
170 num_entries: self.cache.len(),
171 memory_used_mb: self.current_memory_bytes / (1024 * 1024),
172 memory_budget_mb: self.config.memory_budget_mb,
173 hit_rate: 0.0,
174 }
175 }
176}
177
/// Snapshot of prefix-cache occupancy.
#[derive(Clone, Debug, PartialEq)]
pub struct PrefixCacheStats {
    /// Number of cached prefixes.
    pub num_entries: usize,
    /// Estimated memory in use, in whole MiB (rounded down).
    pub memory_used_mb: usize,
    /// Configured budget, in MiB.
    pub memory_budget_mb: usize,
    /// Fraction of lookups that hit. NOTE(review): currently always reported as 0.0.
    pub hit_rate: f64,
}
184
185pub fn hash_prompt(prompt: &str) -> u64 {
186 let mut hasher = Sha256::new();
187 hasher.update(prompt.as_bytes());
188 let result = hasher.finalize();
189 u64::from_le_bytes(result[0..8].try_into().unwrap())
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled cache with the given budget in MiB.
    fn cache_with_budget_mb(memory_budget_mb: usize) -> PrefixCache {
        PrefixCache::new(PrefixCacheConfig {
            memory_budget_mb,
            enabled: true,
        })
    }

    #[test]
    fn test_cache_key_creation() {
        let key1 = CacheKey::new("Hello", Some("System"), "config");
        let key2 = CacheKey::new("Hello", Some("System"), "config");
        let key3 = CacheKey::new("World", Some("System"), "config");

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_prefix_cache_insert() {
        let mut cache = PrefixCache::new(PrefixCacheConfig::default());

        let key = CacheKey::new("test prompt", Some("system"), "config");
        cache.insert(key.clone(), vec![1, 2, 3, 4], Vec::new());

        assert_eq!(cache.stats().num_entries, 1);
        let hit = cache.get(&key).expect("entry was just inserted");
        assert_eq!(hit.tokens, vec![1, 2, 3, 4]);
        assert_eq!(hit.access_count, 1);
    }

    #[test]
    fn test_disabled_cache_stores_nothing() {
        let mut cache = PrefixCache::new(PrefixCacheConfig {
            memory_budget_mb: 512,
            enabled: false,
        });
        let key = CacheKey::new("p", None, "c");
        cache.insert(key.clone(), vec![1, 2, 3], Vec::new());

        assert!(!cache.is_enabled());
        assert!(cache.get(&key).is_none());
        assert_eq!(cache.stats().num_entries, 0);
    }

    #[test]
    fn test_reinsert_same_key_replaces_entry() {
        let mut cache = PrefixCache::new(PrefixCacheConfig::default());
        let key = CacheKey::new("p", None, "c");
        cache.insert(key.clone(), vec![1], Vec::new());
        cache.insert(key.clone(), vec![2, 3], Vec::new());

        assert_eq!(cache.stats().num_entries, 1);
        assert_eq!(cache.get(&key).expect("entry present").tokens, vec![2, 3]);
    }

    #[test]
    fn test_entry_larger_than_budget_is_rejected() {
        // A zero budget cannot hold even the fixed per-entry overhead.
        let mut cache = cache_with_budget_mb(0);
        let key = CacheKey::new("p", None, "c");
        cache.insert(key.clone(), vec![1], Vec::new());

        assert!(cache.get(&key).is_none());
        assert_eq!(cache.stats().num_entries, 0);
    }

    #[test]
    fn test_lru_eviction_respects_touch() {
        // 1 MiB budget; each entry is ~400 KB (100_000 u32 tokens), so only two fit.
        let mut cache = cache_with_budget_mb(1);
        let a = CacheKey::new("a", None, "cfg");
        let b = CacheKey::new("b", None, "cfg");
        let c = CacheKey::new("c", None, "cfg");

        cache.insert(a.clone(), vec![0; 100_000], Vec::new());
        cache.insert(b.clone(), vec![0; 100_000], Vec::new());
        cache.touch(&a);
        cache.insert(c.clone(), vec![0; 100_000], Vec::new());

        assert!(cache.get(&a).is_some(), "touched entry must survive");
        assert!(cache.get(&b).is_none(), "least recently used entry must be evicted");
        assert!(cache.get(&c).is_some());
        assert_eq!(cache.stats().num_entries, 2);
    }

    #[test]
    fn test_clear_resets_state() {
        let mut cache = PrefixCache::new(PrefixCacheConfig::default());
        let key = CacheKey::new("p", None, "c");
        cache.insert(key.clone(), vec![1, 2], Vec::new());
        cache.clear();

        let stats = cache.stats();
        assert_eq!(stats.num_entries, 0);
        assert_eq!(stats.memory_used_mb, 0);
        assert!(cache.get(&key).is_none());
    }
}