Skip to main content

offline_intelligence/cache_management/
cache_extractor.rs

1//! Extracts and preserves important KV cache entries
2
3use regex::Regex;
4use std::collections::HashMap;
5use tracing::{debug, trace};
6
/// Category assigned to a KV cache entry during extraction.
///
/// The first four variants correspond to the structural `key_type` labels of a
/// [`KVEntry`] (`"attention_key"`, `"attention_value"`, `"ffn_key"`,
/// `"ffn_value"`). The remaining variants are content categories that
/// `CacheExtractor` detects by matching regex patterns against an entry's key text.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum CacheEntryType {
    /// Structural label `"attention_key"`.
    AttentionKey,
    /// Structural label `"attention_value"`.
    AttentionValue,
    /// Structural label `"ffn_key"`.
    FFNKey,
    /// Structural label `"ffn_value"`.
    FFNValue,
    /// Key text matched the system-prompt pattern.
    SystemPrompt,
    /// Key text matched the code pattern.
    CodeBlock,
    /// Key text matched the important-concept pattern.
    ImportantConcept,
    /// Key text matched the question pattern.
    Question,
    /// Key text matched the numeric-data pattern.
    NumericData,
    /// Key text matched a user-registered pattern with this name.
    Custom(String),
}
21
22impl std::fmt::Display for CacheEntryType {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            CacheEntryType::AttentionKey => write!(f, "attention_key"),
26            CacheEntryType::AttentionValue => write!(f, "attention_value"),
27            CacheEntryType::FFNKey => write!(f, "ffn_key"),
28            CacheEntryType::FFNValue => write!(f, "ffn_value"),
29            CacheEntryType::SystemPrompt => write!(f, "system_prompt"),
30            CacheEntryType::CodeBlock => write!(f, "code_block"),
31            CacheEntryType::ImportantConcept => write!(f, "important_concept"),
32            CacheEntryType::Question => write!(f, "question"),
33            CacheEntryType::NumericData => write!(f, "numeric_data"),
34            CacheEntryType::Custom(name) => write!(f, "{}", name),
35        }
36    }
37}
38
39/// Represents a KV cache entry (simplified for extraction)
40#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]  // FIX: Added Serialize and Deserialize
41pub struct KVEntry {
42    pub key_hash: String,
43    pub key_data: Option<Vec<u8>>,
44    pub value_data: Vec<u8>,
45    pub key_type: String,
46    pub layer_index: i32,
47    pub head_index: Option<i32>,
48    pub importance_score: f32,
49    pub access_count: i32,
50    pub last_accessed: chrono::DateTime<chrono::Utc>,
51}
52
53/// An extracted KV cache entry with its metadata
54#[derive(Debug, Clone)]
55pub struct ExtractedCacheEntry {
56    pub entry_type: CacheEntryType,
57    pub key_hash: String,
58    pub key_data: Option<Vec<u8>>,
59    pub value_data: Vec<u8>,
60    pub layer_index: i32,
61    pub head_index: Option<i32>,
62    pub importance_score: f32,
63    pub access_count: i32,
64    pub keywords: Vec<String>,
65}
66
67/// Extracts important KV cache entries
68pub struct CacheExtractor {
69    patterns: HashMap<CacheEntryType, Regex>,
70    config: CacheExtractorConfig,
71}
72
/// Settings controlling which entries `CacheExtractor` extracts.
#[derive(Clone, Debug)]
pub struct CacheExtractorConfig {
    /// Entries whose value is shorter than this many bytes are skipped.
    pub min_value_size: usize,
    /// Entries whose value is longer than this many bytes are skipped.
    pub max_value_size: usize,
    /// Whether the scorer is asked for keywords for each extracted entry.
    pub extract_keywords: bool,
    /// Minimum keyword length. Not read by `CacheExtractor` itself; presumably
    /// consumed by a `CacheEntryScorer` implementation — TODO confirm.
    pub keyword_min_length: usize,
}
80
81impl Default for CacheExtractorConfig {
82    fn default() -> Self {
83        Self {
84            min_value_size: 10,
85            max_value_size: 10000,
86            extract_keywords: true,
87            keyword_min_length: 3,
88        }
89    }
90}
91
/// Supplies keywords for extracted cache entries.
///
/// `CacheExtractor::extract_entries` calls this once per entry that passes the
/// value-size check, but only when `CacheExtractorConfig::extract_keywords` is set.
pub trait CacheEntryScorer {
    /// Return keywords for an entry given its raw key bytes
    /// (`None` when the entry carries no key data).
    fn extract_keywords(&self, key_data: Option<&[u8]>) -> Vec<String>;
}
96
97impl CacheExtractor {
98    /// Create a new cache extractor
99    pub fn new(config: CacheExtractorConfig) -> Self {
100        let mut patterns = HashMap::new();
101        
102        // System prompt patterns
103        patterns.insert(
104            CacheEntryType::SystemPrompt,
105            Regex::new(r"(?i)(system|instruction|prompt|assistant_role|you are|your role)").unwrap(),
106        );
107        
108        // Code block patterns
109        patterns.insert(
110            CacheEntryType::CodeBlock,
111            Regex::new(r"```|\b(def|function|class|import|return|print|let|const|var)\b|\b(python|rust|javascript|java|c\+\+|go|sql)\b").unwrap(),
112        );
113        
114        // Important concept patterns
115        patterns.insert(
116            CacheEntryType::ImportantConcept,
117            Regex::new(r"(?i)\b(important|crucial|critical|essential|must|need|require|urgent|priority|key|main|primary)\b").unwrap(),
118        );
119        
120        // Question patterns
121        patterns.insert(
122            CacheEntryType::Question,
123            Regex::new(r"\?$|^(what|how|why|when|where|who|explain|describe|can you|could you|would you|should you)").unwrap(),
124        );
125        
126        // Numeric data patterns
127        patterns.insert(
128            CacheEntryType::NumericData,
129            Regex::new(r"\b\d+(?:\.\d+)?%?\b|\b(date|time|age|year|month|day|hour|minute|second)\b").unwrap(),
130        );
131        
132        Self { patterns, config }
133    }
134    
135    /// Add a custom pattern
136    pub fn add_custom_pattern(&mut self, name: String, pattern: Regex) {
137        self.patterns.insert(CacheEntryType::Custom(name), pattern);
138    }
139    
140    /// Extract important entries from KV cache
141    pub fn extract_entries(
142        &self,
143        entries: &[KVEntry],
144        scorer: &impl CacheEntryScorer,
145    ) -> Vec<ExtractedCacheEntry> {
146        let mut extracted = Vec::new();
147        
148        for entry in entries {
149            // Skip if value size is outside bounds
150            if entry.value_data.len() < self.config.min_value_size 
151                || entry.value_data.len() > self.config.max_value_size {
152                continue;
153            }
154            
155            // Determine entry type based on key patterns
156            let entry_type = self.classify_entry(entry);
157            
158            // Extract keywords if enabled
159            let keywords = if self.config.extract_keywords {
160                scorer.extract_keywords(entry.key_data.as_deref())
161            } else {
162                Vec::new()
163            };
164            
165            let extracted_entry = ExtractedCacheEntry {
166                entry_type,
167                key_hash: entry.key_hash.clone(),
168                key_data: entry.key_data.clone(),
169                value_data: entry.value_data.clone(),
170                layer_index: entry.layer_index,
171                head_index: entry.head_index,
172                importance_score: entry.importance_score,
173                access_count: entry.access_count,
174                keywords,
175            };
176            
177            trace!("Extracted cache entry: {} (score: {})", 
178                extracted_entry.entry_type, extracted_entry.importance_score);
179            
180            extracted.push(extracted_entry);
181        }
182        
183        // Sort by importance score
184        extracted.sort_by(|a, b| b.importance_score.partial_cmp(&a.importance_score)
185            .unwrap_or(std::cmp::Ordering::Equal));
186        
187        debug!("Extracted {} important cache entries", extracted.len());
188        extracted
189    }
190    
191    fn classify_entry(&self, entry: &KVEntry) -> CacheEntryType {
192        // First check key type
193        let key_type_str = entry.key_type.as_str();
194        let base_type = match key_type_str {
195            "attention_key" => CacheEntryType::AttentionKey,
196            "attention_value" => CacheEntryType::AttentionValue,
197            "ffn_key" => CacheEntryType::FFNKey,
198            "ffn_value" => CacheEntryType::FFNValue,
199            _ => CacheEntryType::AttentionKey, // Default
200        };
201        
202        // Check for special patterns in key data
203        if let Some(key_data) = &entry.key_data {
204            if let Ok(key_str) = std::str::from_utf8(key_data) {
205                for (entry_type, pattern) in &self.patterns {
206                    if pattern.is_match(key_str) {
207                        // Return the first matching pattern
208                        return entry_type.clone();
209                    }
210                }
211            }
212        }
213        
214        base_type
215    }
216    
217    /// Filter entries that should be preserved
218    pub fn filter_preserved_entries(
219        &self,
220        entries: &[ExtractedCacheEntry],
221        min_importance: f32,
222        preserve_system: bool,
223        preserve_code: bool,
224    ) -> Vec<ExtractedCacheEntry> {
225        entries.iter()
226            .filter(|entry| {
227                // Check importance threshold
228                if entry.importance_score < min_importance {
229                    return false;
230                }
231                
232                // Check specific preservation rules
233                match &entry.entry_type {
234                    CacheEntryType::SystemPrompt if preserve_system => true,
235                    CacheEntryType::CodeBlock if preserve_code => true,
236                    CacheEntryType::ImportantConcept => true,
237                    CacheEntryType::AttentionKey | CacheEntryType::AttentionValue => true,
238                    _ => entry.importance_score >= min_importance * 1.2, // Higher threshold for others
239                }
240            })
241            .cloned()
242            .collect()
243    }
244}