Skip to main content

offline_intelligence/cache_management/
cache_extractor.rs

1//! Extracts and preserves important KV cache entries
2
3use regex::Regex;
4use std::collections::HashMap;
5use tracing::{debug, trace};
6
/// Classification assigned to a KV cache entry by [`CacheExtractor`].
///
/// The first four variants mirror the `key_type` strings a [`KVEntry`] may carry. The other
/// built-in variants are assigned when an entry's key text matches one of the extractor's
/// regex patterns. [`CacheEntryType::Custom`] carries the name of a user-registered pattern.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CacheEntryType {
    /// Entry whose `key_type` is `"attention_key"`.
    AttentionKey,
    /// Entry whose `key_type` is `"attention_value"`.
    AttentionValue,
    /// Entry whose `key_type` is `"ffn_key"`.
    FFNKey,
    /// Entry whose `key_type` is `"ffn_value"`.
    FFNValue,
    /// Key text matched the system-prompt pattern.
    SystemPrompt,
    /// Key text matched the code pattern.
    CodeBlock,
    /// Key text matched the important-concept pattern.
    ImportantConcept,
    /// Key text matched the question pattern.
    Question,
    /// Key text matched the numeric-data pattern.
    NumericData,
    /// Key text matched a user-registered pattern with this name.
    Custom(String),
}

/// Renders the variant as its snake_case label (e.g. `attention_key`); a `Custom` variant
/// renders as its name, verbatim.
impl std::fmt::Display for CacheEntryType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label: &str = match self {
            Self::AttentionKey => "attention_key",
            Self::AttentionValue => "attention_value",
            Self::FFNKey => "ffn_key",
            Self::FFNValue => "ffn_value",
            Self::SystemPrompt => "system_prompt",
            Self::CodeBlock => "code_block",
            Self::ImportantConcept => "important_concept",
            Self::Question => "question",
            Self::NumericData => "numeric_data",
            Self::Custom(name) => name.as_str(),
        };
        f.write_str(label)
    }
}
38
39/// Represents a KV cache entry (simplified for extraction)
40#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
41pub struct KVEntry {
42    pub key_hash: String,
43    pub key_data: Option<Vec<u8>>,
44    pub value_data: Vec<u8>,
45    pub key_type: String,  // attention_key, attention_value, ffn_key, ffn_value
46    pub layer_index: i32,  // Transformer layer index
47    pub head_index: Option<i32>,  // Attention head index (None for FFN)
48    pub importance_score: f32,
49    pub access_count: i32,
50    pub last_accessed: chrono::DateTime<chrono::Utc>,
51    pub token_positions: Option<Vec<u32>>,  // Token positions in the context
52    pub embedding: Option<Vec<f32>>,       // Optional embedding for semantic search
53    pub size_bytes: usize,                 // Actual size in bytes
54    pub is_persistent: bool,               // Whether this entry should be preserved
55}
56
57/// An extracted KV cache entry with its metadata
58#[derive(Debug, Clone)]
59pub struct ExtractedCacheEntry {
60    pub entry_type: CacheEntryType,
61    pub key_hash: String,
62    pub key_data: Option<Vec<u8>>,
63    pub value_data: Vec<u8>,
64    pub layer_index: i32,
65    pub head_index: Option<i32>,
66    pub importance_score: f32,
67    pub access_count: i32,
68    pub keywords: Vec<String>,
69}
70
71/// Extracts important KV cache entries
72pub struct CacheExtractor {
73    patterns: HashMap<CacheEntryType, Regex>,
74    config: CacheExtractorConfig,
75}
76
/// Tunables for [`CacheExtractor`].
#[derive(Debug, Clone)]
pub struct CacheExtractorConfig {
    /// Smallest `value_data` length, in bytes (inclusive), an entry may have to be extracted.
    pub min_value_size: usize,
    /// Largest `value_data` length, in bytes (inclusive), an entry may have to be extracted.
    pub max_value_size: usize,
    /// Whether keywords are requested from the scorer during extraction.
    pub extract_keywords: bool,
    /// Minimum length for extracted keywords.
    pub keyword_min_length: usize,
}

impl Default for CacheExtractorConfig {
    /// Value sizes bounded to 10..=10_000 bytes, keyword extraction enabled and a
    /// `keyword_min_length` of 3.
    fn default() -> Self {
        let (min_value_size, max_value_size) = (10, 10_000);
        CacheExtractorConfig {
            min_value_size,
            max_value_size,
            extract_keywords: true,
            keyword_min_length: 3,
        }
    }
}
95
/// Supplies keywords for KV cache entries during extraction.
///
/// [`CacheExtractor::extract_entries`] calls this for each entry that passes its value-size
/// bounds, provided keyword extraction is enabled in its config.
pub trait CacheEntryScorer {
    /// Returns keywords describing an entry, given its raw key bytes (`None` when the entry
    /// carries no key data).
    fn extract_keywords(&self, key_data: Option<&[u8]>) -> Vec<String>;
}
100
101impl CacheExtractor {
102    /// Create a new cache extractor
103    pub fn new(config: CacheExtractorConfig) -> Self {
104        let mut patterns = HashMap::new();
105        
106        // System prompt patterns
107        patterns.insert(
108            CacheEntryType::SystemPrompt,
109            Regex::new(r"(?i)(system|instruction|prompt|assistant_role|you are|your role)").unwrap(),
110        );
111        
112        // Code block patterns
113        patterns.insert(
114            CacheEntryType::CodeBlock,
115            Regex::new(r"```|\b(def|function|class|import|return|print|let|const|var)\b|\b(python|rust|javascript|java|c\+\+|go|sql)\b").unwrap(),
116        );
117        
118        // Important concept patterns
119        patterns.insert(
120            CacheEntryType::ImportantConcept,
121            Regex::new(r"(?i)\b(important|crucial|critical|essential|must|need|require|urgent|priority|key|main|primary)\b").unwrap(),
122        );
123        
124        // Question patterns
125        patterns.insert(
126            CacheEntryType::Question,
127            Regex::new(r"\?$|^(what|how|why|when|where|who|explain|describe|can you|could you|would you|should you)").unwrap(),
128        );
129        
130        // Numeric data patterns
131        patterns.insert(
132            CacheEntryType::NumericData,
133            Regex::new(r"\b\d+(?:\.\d+)?%?\b|\b(date|time|age|year|month|day|hour|minute|second)\b").unwrap(),
134        );
135        
136        Self { patterns, config }
137    }
138    
139    /// Add a custom pattern
140    pub fn add_custom_pattern(&mut self, name: String, pattern: Regex) {
141        self.patterns.insert(CacheEntryType::Custom(name), pattern);
142    }
143    
144    /// Extract important entries from KV cache
145    pub fn extract_entries(
146        &self,
147        entries: &[KVEntry],
148        scorer: &impl CacheEntryScorer,
149    ) -> Vec<ExtractedCacheEntry> {
150        let mut extracted = Vec::new();
151        
152        for entry in entries {
153            // Skip if value size is outside bounds
154            if entry.value_data.len() < self.config.min_value_size 
155                || entry.value_data.len() > self.config.max_value_size {
156                continue;
157            }
158            
159            // Determine entry type based on key patterns
160            let entry_type = self.classify_entry(entry);
161            
162            // Extract keywords if enabled
163            let keywords = if self.config.extract_keywords {
164                scorer.extract_keywords(entry.key_data.as_deref())
165            } else {
166                Vec::new()
167            };
168            
169            let extracted_entry = ExtractedCacheEntry {
170                entry_type,
171                key_hash: entry.key_hash.clone(),
172                key_data: entry.key_data.clone(),
173                value_data: entry.value_data.clone(),
174                layer_index: entry.layer_index,
175                head_index: entry.head_index,
176                importance_score: entry.importance_score,
177                access_count: entry.access_count,
178                keywords,
179            };
180            
181            trace!("Extracted cache entry: {} (score: {})", 
182                extracted_entry.entry_type, extracted_entry.importance_score);
183            
184            extracted.push(extracted_entry);
185        }
186        
187        // Sort by importance score
188        extracted.sort_by(|a, b| b.importance_score.partial_cmp(&a.importance_score)
189            .unwrap_or(std::cmp::Ordering::Equal));
190        
191        debug!("Extracted {} important cache entries", extracted.len());
192        extracted
193    }
194    
195    fn classify_entry(&self, entry: &KVEntry) -> CacheEntryType {
196        // First check key type
197        let key_type_str = entry.key_type.as_str();
198        let base_type = match key_type_str {
199            "attention_key" => CacheEntryType::AttentionKey,
200            "attention_value" => CacheEntryType::AttentionValue,
201            "ffn_key" => CacheEntryType::FFNKey,
202            "ffn_value" => CacheEntryType::FFNValue,
203            _ => CacheEntryType::AttentionKey, // Default
204        };
205        
206        // Check for special patterns in key data
207        if let Some(key_data) = &entry.key_data {
208            if let Ok(key_str) = std::str::from_utf8(key_data) {
209                for (entry_type, pattern) in &self.patterns {
210                    if pattern.is_match(key_str) {
211                        // Return the first matching pattern
212                        return entry_type.clone();
213                    }
214                }
215            }
216        }
217        
218        base_type
219    }
220    
221    /// Filter entries that should be preserved
222    pub fn filter_preserved_entries(
223        &self,
224        entries: &[ExtractedCacheEntry],
225        min_importance: f32,
226        preserve_system: bool,
227        preserve_code: bool,
228    ) -> Vec<ExtractedCacheEntry> {
229        entries.iter()
230            .filter(|entry| {
231                // Check importance threshold
232                if entry.importance_score < min_importance {
233                    return false;
234                }
235                
236                // Check specific preservation rules
237                match &entry.entry_type {
238                    CacheEntryType::SystemPrompt if preserve_system => true,
239                    CacheEntryType::CodeBlock if preserve_code => true,
240                    CacheEntryType::ImportantConcept => true,
241                    CacheEntryType::AttentionKey | CacheEntryType::AttentionValue => true,
242                    _ => entry.importance_score >= min_importance * 1.2, // Higher threshold for others
243                }
244            })
245            .cloned()
246            .collect()
247    }
248}