Skip to main content

llama_gguf/model/
cache.rs

//! Prompt caching and prefix sharing.
//!
//! This module caches the KV-cache tensors computed for common prompt
//! prefixes so they can be reused instead of recomputed, which speeds up
//! inference for:
//! - System prompts that are reused across conversations
//! - Common instruction prefixes
//! - RAG context that is shared across queries
8
9use std::collections::HashMap;
10use std::hash::{Hash, Hasher};
11
12use crate::tensor::Tensor;
13
/// Hash-derived key identifying a cached token prefix.
///
/// Identical token sequences map to the same id. Because the id is a 64-bit
/// hash, distinct sequences can (very rarely) collide, so callers that need
/// certainty should compare the stored tokens as well.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PrefixId(pub u64);

impl PrefixId {
    /// Derive the id for `tokens` by hashing the slice with the standard
    /// library's `DefaultHasher`.
    ///
    /// `DefaultHasher` output is not guaranteed to be stable across Rust
    /// releases, so these ids should not be persisted between builds.
    pub fn from_tokens(tokens: &[u32]) -> Self {
        use std::collections::hash_map::DefaultHasher;

        let mut state = DefaultHasher::new();
        Hash::hash(tokens, &mut state);
        Self(state.finish())
    }
}
27
28/// Cached KV state for a prefix
29#[derive(Debug, Clone)]
30pub struct CachedPrefix {
31    /// The tokens that make up this prefix
32    pub tokens: Vec<u32>,
33    /// Cached key tensors per layer
34    pub k_cache: Vec<Tensor>,
35    /// Cached value tensors per layer  
36    pub v_cache: Vec<Tensor>,
37    /// Number of tokens cached
38    pub seq_len: usize,
39    /// Reference count (for LRU eviction)
40    pub ref_count: usize,
41    /// Last access time
42    pub last_access: std::time::Instant,
43}
44
45impl CachedPrefix {
46    /// Create a new cached prefix
47    pub fn new(tokens: Vec<u32>, k_cache: Vec<Tensor>, v_cache: Vec<Tensor>) -> Self {
48        let seq_len = tokens.len();
49        Self {
50            tokens,
51            k_cache,
52            v_cache,
53            seq_len,
54            ref_count: 0,
55            last_access: std::time::Instant::now(),
56        }
57    }
58
59    /// Memory size in bytes
60    pub fn memory_size(&self) -> usize {
61        let k_size: usize = self.k_cache.iter().map(|t| t.data().len()).sum();
62        let v_size: usize = self.v_cache.iter().map(|t| t.data().len()).sum();
63        k_size + v_size + self.tokens.len() * 4
64    }
65}
66
/// Tuning knobs for the prompt cache.
#[derive(Debug, Clone)]
pub struct PromptCacheConfig {
    /// Number of prefixes the cache aims to hold at most; inserts evict
    /// unpinned entries to stay within it.
    pub max_entries: usize,
    /// Budget for the summed `memory_size()` of all cached prefixes (bytes).
    pub max_memory: usize,
    /// Prefixes shorter than this many tokens are never cached.
    pub min_prefix_len: usize,
    /// Request automatic caching of system prompts.
    /// NOTE(review): not read anywhere in this module.
    pub cache_system_prompts: bool,
}

impl Default for PromptCacheConfig {
    /// 100 entries, a 1 GiB budget, a 32-token minimum, and system-prompt
    /// caching enabled.
    fn default() -> Self {
        const ONE_GIB: usize = 1 << 30;

        PromptCacheConfig {
            cache_system_prompts: true,
            min_prefix_len: 32,
            max_memory: ONE_GIB,
            max_entries: 100,
        }
    }
}
90
91/// Prompt cache for prefix sharing
92pub struct PromptCache {
93    /// Configuration
94    config: PromptCacheConfig,
95    /// Cached prefixes by ID
96    entries: HashMap<PrefixId, CachedPrefix>,
97    /// Current memory usage
98    memory_used: usize,
99}
100
101impl PromptCache {
102    /// Create a new prompt cache
103    pub fn new(config: PromptCacheConfig) -> Self {
104        Self {
105            config,
106            entries: HashMap::new(),
107            memory_used: 0,
108        }
109    }
110
111    /// Cache a prefix
112    pub fn cache_prefix(
113        &mut self,
114        tokens: &[u32],
115        k_cache: Vec<Tensor>,
116        v_cache: Vec<Tensor>,
117    ) -> PrefixId {
118        let id = PrefixId::from_tokens(tokens);
119
120        // Check if already cached
121        if self.entries.contains_key(&id) {
122            if let Some(entry) = self.entries.get_mut(&id) {
123                entry.ref_count += 1;
124                entry.last_access = std::time::Instant::now();
125            }
126            return id;
127        }
128
129        // Check if prefix is long enough
130        if tokens.len() < self.config.min_prefix_len {
131            return id;
132        }
133
134        let prefix = CachedPrefix::new(tokens.to_vec(), k_cache, v_cache);
135        let size = prefix.memory_size();
136
137        // Evict if necessary
138        while self.memory_used + size > self.config.max_memory
139            || self.entries.len() >= self.config.max_entries
140        {
141            if !self.evict_lru() {
142                break;
143            }
144        }
145
146        self.memory_used += size;
147        self.entries.insert(id.clone(), prefix);
148
149        id
150    }
151
152    /// Get a cached prefix
153    pub fn get_prefix(&mut self, id: &PrefixId) -> Option<&CachedPrefix> {
154        if let Some(entry) = self.entries.get_mut(id) {
155            entry.ref_count += 1;
156            entry.last_access = std::time::Instant::now();
157            Some(entry)
158        } else {
159            None
160        }
161    }
162
163    /// Find the longest matching prefix
164    pub fn find_matching_prefix(&mut self, tokens: &[u32]) -> Option<(PrefixId, usize)> {
165        let mut best_match: Option<(PrefixId, usize)> = None;
166
167        for (id, entry) in &self.entries {
168            // Check if this prefix matches the start of tokens
169            if tokens.len() >= entry.tokens.len()
170                && tokens[..entry.tokens.len()] == entry.tokens[..]
171            {
172                let match_len = entry.tokens.len();
173                if best_match.is_none() || match_len > best_match.as_ref().unwrap().1 {
174                    best_match = Some((id.clone(), match_len));
175                }
176            }
177        }
178
179        // Update access time for matched entry
180        if let Some((ref id, _)) = best_match
181            && let Some(entry) = self.entries.get_mut(id)
182        {
183            entry.last_access = std::time::Instant::now();
184            entry.ref_count += 1;
185        }
186
187        best_match
188    }
189
190    /// Remove a prefix from cache
191    pub fn remove_prefix(&mut self, id: &PrefixId) {
192        if let Some(entry) = self.entries.remove(id) {
193            self.memory_used = self.memory_used.saturating_sub(entry.memory_size());
194        }
195    }
196
197    /// Clear all cached prefixes
198    pub fn clear(&mut self) {
199        self.entries.clear();
200        self.memory_used = 0;
201    }
202
203    /// Get cache statistics
204    pub fn stats(&self) -> PromptCacheStats {
205        PromptCacheStats {
206            num_entries: self.entries.len(),
207            memory_used: self.memory_used,
208            total_tokens_cached: self.entries.values().map(|e| e.seq_len).sum(),
209        }
210    }
211
212    /// Evict the least recently used entry
213    fn evict_lru(&mut self) -> bool {
214        // Find LRU entry (oldest last_access with ref_count == 0)
215        let lru_id = self
216            .entries
217            .iter()
218            .filter(|(_, e)| e.ref_count == 0)
219            .min_by_key(|(_, e)| e.last_access)
220            .map(|(id, _)| id.clone());
221
222        if let Some(id) = lru_id {
223            self.remove_prefix(&id);
224            true
225        } else {
226            false
227        }
228    }
229
230    /// Decrease reference count for a prefix
231    pub fn release_prefix(&mut self, id: &PrefixId) {
232        if let Some(entry) = self.entries.get_mut(id) {
233            entry.ref_count = entry.ref_count.saturating_sub(1);
234        }
235    }
236}
237
/// Point-in-time occupancy snapshot of a prompt cache.
#[derive(Debug, Clone)]
pub struct PromptCacheStats {
    /// How many prefixes are currently cached.
    pub num_entries: usize,
    /// Tracked memory footprint of all entries, in bytes.
    pub memory_used: usize,
    /// Sum of `seq_len` across all cached prefixes.
    pub total_tokens_cached: usize,
}
248
249/// Helper to manage prefix sharing in inference
250pub struct PrefixSharing {
251    /// The prompt cache
252    cache: PromptCache,
253    /// Active prefix ID for current session
254    active_prefix: Option<PrefixId>,
255}
256
257impl PrefixSharing {
258    /// Create a new prefix sharing manager
259    pub fn new(config: PromptCacheConfig) -> Self {
260        Self {
261            cache: PromptCache::new(config),
262            active_prefix: None,
263        }
264    }
265
266    /// Try to restore cached prefix into KV cache
267    ///
268    /// Returns the number of tokens restored (0 if no match)
269    pub fn try_restore(
270        &mut self,
271        tokens: &[u32],
272        k_cache: &mut [Tensor],
273        v_cache: &mut [Tensor],
274    ) -> usize {
275        // Find matching prefix
276        let (id, match_len) = match self.cache.find_matching_prefix(tokens) {
277            Some(m) => m,
278            None => return 0,
279        };
280
281        // Get cached data
282        let prefix = match self.cache.get_prefix(&id) {
283            Some(p) => p,
284            None => return 0,
285        };
286
287        // Copy cached KV to current cache
288        for (layer_idx, (cached_k, cached_v)) in
289            prefix.k_cache.iter().zip(prefix.v_cache.iter()).enumerate()
290        {
291            if layer_idx < k_cache.len() {
292                // Copy cached data
293                let k_src = cached_k.data();
294                let v_src = cached_v.data();
295
296                if let Some(k_dst) = k_cache[layer_idx].data_mut() {
297                    let copy_len = k_src.len().min(k_dst.len());
298                    k_dst[..copy_len].copy_from_slice(&k_src[..copy_len]);
299                }
300
301                if let Some(v_dst) = v_cache[layer_idx].data_mut() {
302                    let copy_len = v_src.len().min(v_dst.len());
303                    v_dst[..copy_len].copy_from_slice(&v_src[..copy_len]);
304                }
305            }
306        }
307
308        self.active_prefix = Some(id);
309        match_len
310    }
311
312    /// Save current KV cache as a prefix
313    pub fn save_prefix(
314        &mut self,
315        tokens: &[u32],
316        k_cache: &[Tensor],
317        v_cache: &[Tensor],
318    ) -> PrefixId {
319        // Clone the cache tensors
320        let k_cloned: Vec<Tensor> = k_cache.to_vec();
321        let v_cloned: Vec<Tensor> = v_cache.to_vec();
322
323        let id = self.cache.cache_prefix(tokens, k_cloned, v_cloned);
324        self.active_prefix = Some(id.clone());
325        id
326    }
327
328    /// Release the active prefix
329    pub fn release_active(&mut self) {
330        if let Some(id) = self.active_prefix.take() {
331            self.cache.release_prefix(&id);
332        }
333    }
334
335    /// Get cache statistics
336    pub fn stats(&self) -> PromptCacheStats {
337        self.cache.stats()
338    }
339
340    /// Clear all cached prefixes
341    pub fn clear(&mut self) {
342        self.active_prefix = None;
343        self.cache.clear();
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::DType;

    /// A single-layer K or V cache filled with zeros.
    fn zero_layers() -> Vec<Tensor> {
        vec![Tensor::zeros(vec![4, 4], DType::F32)]
    }

    /// Default config, except that prefixes of `min_prefix_len` tokens or more are cacheable.
    fn config_with_min_len(min_prefix_len: usize) -> PromptCacheConfig {
        PromptCacheConfig {
            min_prefix_len,
            ..Default::default()
        }
    }

    #[test]
    fn test_prefix_id() {
        let first = PrefixId::from_tokens(&[1, 2, 3, 4]);
        let same_tokens = PrefixId::from_tokens(&[1, 2, 3, 4]);
        let last_differs = PrefixId::from_tokens(&[1, 2, 3, 5]);

        assert_eq!(first, same_tokens);
        assert_ne!(first, last_differs);
    }

    #[test]
    fn test_prompt_cache() {
        let mut cache = PromptCache::new(config_with_min_len(2));

        let id = cache.cache_prefix(&[1, 2, 3, 4, 5], zero_layers(), zero_layers());

        assert!(cache.get_prefix(&id).is_some());
        assert_eq!(cache.stats().num_entries, 1);
    }

    #[test]
    fn test_find_matching_prefix() {
        let mut cache = PromptCache::new(config_with_min_len(2));
        cache.cache_prefix(&[1, 2, 3], zero_layers(), zero_layers());

        // Query extends the cached prefix: the full prefix length matches.
        let hit = cache.find_matching_prefix(&[1, 2, 3, 4, 5]);
        assert_eq!(hit.map(|(_, len)| len), Some(3));

        // Query diverges at the third token: nothing matches.
        let miss = cache.find_matching_prefix(&[1, 2, 4, 5]);
        assert!(miss.is_none());
    }

    #[test]
    fn test_cache_eviction() {
        let mut cache = PromptCache::new(PromptCacheConfig {
            max_entries: 2,
            min_prefix_len: 1,
            ..Default::default()
        });

        // Three single-token prefixes into a two-entry cache forces an eviction.
        for token in 0u32..3 {
            cache.cache_prefix(&[token], zero_layers(), zero_layers());
        }

        assert!(cache.stats().num_entries <= 2);
    }
}