1use std::collections::HashMap;
10use std::hash::{Hash, Hasher};
11
12use crate::tensor::Tensor;
13
/// Identifier for a cached prompt prefix: a 64-bit hash of its token ids.
///
/// Equal token sequences always produce equal ids; distinct sequences can, in principle,
/// collide. NOTE: `DefaultHasher`'s algorithm is not guaranteed to stay the same across
/// Rust releases, so ids should not be persisted or exchanged between builds.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct PrefixId(pub u64);

impl PrefixId {
    /// Hashes `tokens` with a freshly created `DefaultHasher` and wraps the digest.
    pub fn from_tokens(tokens: &[u32]) -> Self {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        Hash::hash(tokens, &mut state);
        Self(state.finish())
    }
}
27
28#[derive(Debug, Clone)]
30pub struct CachedPrefix {
31 pub tokens: Vec<u32>,
33 pub k_cache: Vec<Tensor>,
35 pub v_cache: Vec<Tensor>,
37 pub seq_len: usize,
39 pub ref_count: usize,
41 pub last_access: std::time::Instant,
43}
44
45impl CachedPrefix {
46 pub fn new(tokens: Vec<u32>, k_cache: Vec<Tensor>, v_cache: Vec<Tensor>) -> Self {
48 let seq_len = tokens.len();
49 Self {
50 tokens,
51 k_cache,
52 v_cache,
53 seq_len,
54 ref_count: 0,
55 last_access: std::time::Instant::now(),
56 }
57 }
58
59 pub fn memory_size(&self) -> usize {
61 let k_size: usize = self.k_cache.iter().map(|t| t.data().len()).sum();
62 let v_size: usize = self.v_cache.iter().map(|t| t.data().len()).sum();
63 k_size + v_size + self.tokens.len() * 4
64 }
65}
66
/// Limits and thresholds for a prompt-prefix cache.
#[derive(Clone, Debug)]
pub struct PromptCacheConfig {
    /// Maximum number of prefixes held at once.
    pub max_entries: usize,
    /// Memory budget, in the units returned by `CachedPrefix::memory_size`.
    pub max_memory: usize,
    /// Prefixes with fewer tokens than this are not cached.
    pub min_prefix_len: usize,
    /// Whether system prompts should be cached.
    /// NOTE(review): not read by the cache code visible here — confirm where it is used.
    pub cache_system_prompts: bool,
}

impl Default for PromptCacheConfig {
    /// 100 entries, a budget of 2^30 `memory_size` units, a 32-token minimum prefix,
    /// and system-prompt caching enabled.
    fn default() -> Self {
        const ONE_GI: usize = 1 << 30;
        PromptCacheConfig {
            max_entries: 100,
            max_memory: ONE_GI,
            min_prefix_len: 32,
            cache_system_prompts: true,
        }
    }
}
90
91pub struct PromptCache {
93 config: PromptCacheConfig,
95 entries: HashMap<PrefixId, CachedPrefix>,
97 memory_used: usize,
99}
100
101impl PromptCache {
102 pub fn new(config: PromptCacheConfig) -> Self {
104 Self {
105 config,
106 entries: HashMap::new(),
107 memory_used: 0,
108 }
109 }
110
111 pub fn cache_prefix(
113 &mut self,
114 tokens: &[u32],
115 k_cache: Vec<Tensor>,
116 v_cache: Vec<Tensor>,
117 ) -> PrefixId {
118 let id = PrefixId::from_tokens(tokens);
119
120 if self.entries.contains_key(&id) {
122 if let Some(entry) = self.entries.get_mut(&id) {
123 entry.ref_count += 1;
124 entry.last_access = std::time::Instant::now();
125 }
126 return id;
127 }
128
129 if tokens.len() < self.config.min_prefix_len {
131 return id;
132 }
133
134 let prefix = CachedPrefix::new(tokens.to_vec(), k_cache, v_cache);
135 let size = prefix.memory_size();
136
137 while self.memory_used + size > self.config.max_memory
139 || self.entries.len() >= self.config.max_entries
140 {
141 if !self.evict_lru() {
142 break;
143 }
144 }
145
146 self.memory_used += size;
147 self.entries.insert(id.clone(), prefix);
148
149 id
150 }
151
152 pub fn get_prefix(&mut self, id: &PrefixId) -> Option<&CachedPrefix> {
154 if let Some(entry) = self.entries.get_mut(id) {
155 entry.ref_count += 1;
156 entry.last_access = std::time::Instant::now();
157 Some(entry)
158 } else {
159 None
160 }
161 }
162
163 pub fn find_matching_prefix(&mut self, tokens: &[u32]) -> Option<(PrefixId, usize)> {
165 let mut best_match: Option<(PrefixId, usize)> = None;
166
167 for (id, entry) in &self.entries {
168 if tokens.len() >= entry.tokens.len()
170 && tokens[..entry.tokens.len()] == entry.tokens[..]
171 {
172 let match_len = entry.tokens.len();
173 if best_match.is_none() || match_len > best_match.as_ref().unwrap().1 {
174 best_match = Some((id.clone(), match_len));
175 }
176 }
177 }
178
179 if let Some((ref id, _)) = best_match
181 && let Some(entry) = self.entries.get_mut(id)
182 {
183 entry.last_access = std::time::Instant::now();
184 entry.ref_count += 1;
185 }
186
187 best_match
188 }
189
190 pub fn remove_prefix(&mut self, id: &PrefixId) {
192 if let Some(entry) = self.entries.remove(id) {
193 self.memory_used = self.memory_used.saturating_sub(entry.memory_size());
194 }
195 }
196
197 pub fn clear(&mut self) {
199 self.entries.clear();
200 self.memory_used = 0;
201 }
202
203 pub fn stats(&self) -> PromptCacheStats {
205 PromptCacheStats {
206 num_entries: self.entries.len(),
207 memory_used: self.memory_used,
208 total_tokens_cached: self.entries.values().map(|e| e.seq_len).sum(),
209 }
210 }
211
212 fn evict_lru(&mut self) -> bool {
214 let lru_id = self
216 .entries
217 .iter()
218 .filter(|(_, e)| e.ref_count == 0)
219 .min_by_key(|(_, e)| e.last_access)
220 .map(|(id, _)| id.clone());
221
222 if let Some(id) = lru_id {
223 self.remove_prefix(&id);
224 true
225 } else {
226 false
227 }
228 }
229
230 pub fn release_prefix(&mut self, id: &PrefixId) {
232 if let Some(entry) = self.entries.get_mut(id) {
233 entry.ref_count = entry.ref_count.saturating_sub(1);
234 }
235 }
236}
237
/// Point-in-time snapshot of a prompt cache, as produced by `PromptCache::stats`.
#[derive(Clone, Debug)]
pub struct PromptCacheStats {
    /// Number of cached prefixes.
    pub num_entries: usize,
    /// Tracked memory of all cached prefixes, in `CachedPrefix::memory_size` units.
    pub memory_used: usize,
    /// Sum of `seq_len` (token count) over all cached prefixes.
    pub total_tokens_cached: usize,
}
248
249pub struct PrefixSharing {
251 cache: PromptCache,
253 active_prefix: Option<PrefixId>,
255}
256
257impl PrefixSharing {
258 pub fn new(config: PromptCacheConfig) -> Self {
260 Self {
261 cache: PromptCache::new(config),
262 active_prefix: None,
263 }
264 }
265
266 pub fn try_restore(
270 &mut self,
271 tokens: &[u32],
272 k_cache: &mut [Tensor],
273 v_cache: &mut [Tensor],
274 ) -> usize {
275 let (id, match_len) = match self.cache.find_matching_prefix(tokens) {
277 Some(m) => m,
278 None => return 0,
279 };
280
281 let prefix = match self.cache.get_prefix(&id) {
283 Some(p) => p,
284 None => return 0,
285 };
286
287 for (layer_idx, (cached_k, cached_v)) in
289 prefix.k_cache.iter().zip(prefix.v_cache.iter()).enumerate()
290 {
291 if layer_idx < k_cache.len() {
292 let k_src = cached_k.data();
294 let v_src = cached_v.data();
295
296 if let Some(k_dst) = k_cache[layer_idx].data_mut() {
297 let copy_len = k_src.len().min(k_dst.len());
298 k_dst[..copy_len].copy_from_slice(&k_src[..copy_len]);
299 }
300
301 if let Some(v_dst) = v_cache[layer_idx].data_mut() {
302 let copy_len = v_src.len().min(v_dst.len());
303 v_dst[..copy_len].copy_from_slice(&v_src[..copy_len]);
304 }
305 }
306 }
307
308 self.active_prefix = Some(id);
309 match_len
310 }
311
312 pub fn save_prefix(
314 &mut self,
315 tokens: &[u32],
316 k_cache: &[Tensor],
317 v_cache: &[Tensor],
318 ) -> PrefixId {
319 let k_cloned: Vec<Tensor> = k_cache.to_vec();
321 let v_cloned: Vec<Tensor> = v_cache.to_vec();
322
323 let id = self.cache.cache_prefix(tokens, k_cloned, v_cloned);
324 self.active_prefix = Some(id.clone());
325 id
326 }
327
328 pub fn release_active(&mut self) {
330 if let Some(id) = self.active_prefix.take() {
331 self.cache.release_prefix(&id);
332 }
333 }
334
335 pub fn stats(&self) -> PromptCacheStats {
337 self.cache.stats()
338 }
339
340 pub fn clear(&mut self) {
342 self.active_prefix = None;
343 self.cache.clear();
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::DType;

    /// One layer of zero-filled 4x4 key and value tensors.
    fn zero_kv() -> (Vec<Tensor>, Vec<Tensor>) {
        (
            vec![Tensor::zeros(vec![4, 4], DType::F32)],
            vec![Tensor::zeros(vec![4, 4], DType::F32)],
        )
    }

    /// Default config with a custom minimum prefix length.
    fn config_with_min_len(min_prefix_len: usize) -> PromptCacheConfig {
        PromptCacheConfig {
            min_prefix_len,
            ..Default::default()
        }
    }

    #[test]
    fn test_prefix_id() {
        let base = PrefixId::from_tokens(&[1, 2, 3, 4]);

        assert_eq!(base, PrefixId::from_tokens(&[1, 2, 3, 4]));
        assert_ne!(base, PrefixId::from_tokens(&[1, 2, 3, 5]));
    }

    #[test]
    fn test_prompt_cache() {
        let mut cache = PromptCache::new(config_with_min_len(2));
        let (k, v) = zero_kv();

        let id = cache.cache_prefix(&[1, 2, 3, 4, 5], k, v);

        assert!(cache.get_prefix(&id).is_some());
        assert_eq!(cache.stats().num_entries, 1);
    }

    #[test]
    fn test_find_matching_prefix() {
        let mut cache = PromptCache::new(config_with_min_len(2));
        let (k, v) = zero_kv();
        cache.cache_prefix(&[1, 2, 3], k, v);

        // The cached [1, 2, 3] is a prefix of this query.
        let hit = cache.find_matching_prefix(&[1, 2, 3, 4, 5]);
        assert_eq!(hit.map(|(_, len)| len), Some(3));

        // Diverges at the third token, so nothing matches.
        assert!(cache.find_matching_prefix(&[1, 2, 4, 5]).is_none());
    }

    #[test]
    fn test_cache_eviction() {
        let config = PromptCacheConfig {
            max_entries: 2,
            min_prefix_len: 1,
            ..Default::default()
        };
        let mut cache = PromptCache::new(config);

        for token in 0u32..3 {
            let (k, v) = zero_kv();
            cache.cache_prefix(&[token], k, v);
        }

        assert!(cache.stats().num_entries <= 2);
    }
}