//! Scores the importance of KV cache entries for retention and retrieval
//! (offline_intelligence/cache_management/cache_scorer.rs)

use regex::Regex;
use std::collections::HashMap;
use lazy_static::lazy_static;

7lazy_static! {
8    static ref KEY_PATTERNS: HashMap<&'static str, Regex> = {
9        let mut m = HashMap::new();
10        
11        // System prompt patterns
12        m.insert(
13            "system_prompt",
14            Regex::new(r"system|instruction|prompt|assistant_role").unwrap(),
15        );
16        
17        // Code-related patterns
18        m.insert(
19            "code_related",
20            Regex::new(r"def |function |class |import |return |print |code|program|algorithm|python|rust|javascript|java|c\+\+|sql|```").unwrap(),
21        );
22        
23        // Important concept patterns
24        m.insert(
25            "important_concept",
26            Regex::new(r"important|critical|crucial|essential|must|need|require|urgent|asap|priority|key|main|primary").unwrap(),
27        );
28        
29        // Question patterns
30        m.insert(
31            "question",
32            Regex::new(r"what|how|why|when|where|who|explain|describe|can you|could you|would you|should").unwrap(),
33        );
34        
35        // Number/date patterns
36        m.insert(
37            "numeric",
38            Regex::new(r"\d+|date|time|age|year|month|day|hour|minute|second").unwrap(),
39        );
40        
41        m
42    };
43}
44
/// Parameters for scoring a cache entry
pub struct CacheEntryParams<'a> {
    /// Stable hash identifying the entry; used to look up engagement history.
    pub key_hash: &'a str,
    /// Raw key bytes, if available; scanned for textual patterns when valid UTF-8.
    pub key_data: Option<&'a [u8]>,
    /// Kind of entry, e.g. "attention_key", "attention_value", "ffn_key", "ffn_value".
    pub key_type: &'a str,
    /// Transformer layer the entry belongs to (earlier layers score higher).
    pub layer_index: i32,
    /// Attention head index, if applicable (heads < 4 score higher).
    pub head_index: Option<i32>,
    /// How many times this entry has been accessed (saturates at 100 for scoring).
    pub access_count: i32,
    /// Seconds elapsed since the last access; recency decays on an hourly scale.
    pub last_accessed_seconds_ago: f32,
    /// Size of the cached value in bytes (saturates at 10 KB for scoring).
    pub value_size_bytes: usize,
}

/// Scores importance of KV cache entries
pub struct CacheEntryScorer {
    // Per-key engagement level keyed by key hash, kept within the config's
    // [min_engagement, max_engagement]; bumped on access, decayed otherwise.
    key_engagement: HashMap<String, f32>, // Tracks frequently accessed keys
    // Weights and bounds used by all scoring components.
    config: CacheScoringConfig,
}

/// Weights and bounds that control how cache-entry importance is computed.
#[derive(Debug, Clone)]
pub struct CacheScoringConfig {
    /// Weight of the recency component (hyperbolic decay on an hourly scale).
    pub recency_weight: f32,
    /// Weight of the access-count component (saturates at 100 accesses).
    pub access_count_weight: f32,
    /// Weight of the key-type / key-content pattern component.
    pub key_pattern_weight: f32,
    /// Weight of the layer-position component (early layers favored).
    pub layer_weight: f32,
    /// Weight of the head-position component (first few heads favored).
    pub head_weight: f32,
    /// Weight of the value-size component (saturates at 10 KB).
    pub value_size_weight: f32,
    /// Multiplicative decay applied to every non-accessed key's engagement.
    pub engagement_decay: f32,
    /// Lower bound for tracked engagement values.
    pub min_engagement: f32,
    /// Upper bound for tracked engagement values.
    pub max_engagement: f32,
}

impl Default for CacheScoringConfig {
    /// Default weights. The six component weights sum to 1.0
    /// (0.3 + 0.2 + 0.25 + 0.1 + 0.05 + 0.1), so a maximal entry scores
    /// about 1.0 before the engagement bonus is added.
    fn default() -> Self {
        Self {
            recency_weight: 0.3,
            access_count_weight: 0.2,
            key_pattern_weight: 0.25,
            layer_weight: 0.1,
            head_weight: 0.05,
            value_size_weight: 0.1,
            engagement_decay: 0.95,
            min_engagement: 0.1,
            max_engagement: 1.0,
        }
    }
}

92impl CacheEntryScorer {
93    /// Create a new cache entry scorer
94    pub fn new(config: CacheScoringConfig) -> Self {
95        Self {
96            key_engagement: HashMap::new(),
97            config,
98        }
99    }
100
101    /// Score a KV cache entry based on various factors
102    pub fn score_entry(&self, params: CacheEntryParams) -> f32 {
103        let mut score = 0.0;
104
105        score += self.score_recency(params.last_accessed_seconds_ago);
106        score += self.score_access_count(params.access_count);
107        score += self.score_key_patterns(params.key_data, params.key_type);
108        score += self.score_layer_position(params.layer_index);
109        score += self.score_head_position(params.head_index);
110        score += self.score_value_size(params.value_size_bytes);
111        score += self.score_key_engagement(params.key_hash);
112
113        score.clamp(0.0, 1.0)
114    }
115
116    fn score_recency(&self, seconds_ago: f32) -> f32 {
117        let recency_factor = 1.0 / (1.0 + seconds_ago / 3600.0); // Hours decay
118        recency_factor * self.config.recency_weight
119    }
120
121    fn score_access_count(&self, access_count: i32) -> f32 {
122        let normalized = (access_count as f32).min(100.0) / 100.0;
123        normalized * self.config.access_count_weight
124    }
125
126    fn score_key_patterns(&self, key_data: Option<&[u8]>, key_type: &str) -> f32 {
127        // Explicitly specify f32 type to fix ambiguity
128        let mut pattern_score: f32 = 0.0; 
129        
130        // Check key type
131        match key_type {
132            "attention_key" | "attention_value" => pattern_score += 0.1,
133            "ffn_key" | "ffn_value" => pattern_score += 0.05,
134            _ => {}
135        }
136        
137        // Check key data patterns if available
138        if let Some(data) = key_data {
139            if let Ok(key_str) = std::str::from_utf8(data) {
140                for (pattern_name, regex) in KEY_PATTERNS.iter() {
141                    if regex.is_match(key_str) {
142                        let weight = match *pattern_name {
143                            "system_prompt" => 0.8,
144                            "code_related" => 0.7,
145                            "important_concept" => 0.9,
146                            "question" => 0.6,
147                            "numeric" => 0.5,
148                            _ => 0.3,
149                        };
150                        pattern_score += weight;
151                    }
152                }
153            }
154        }
155        
156        pattern_score.min(1.0) * self.config.key_pattern_weight
157    }
158
159    fn score_layer_position(&self, layer_index: i32) -> f32 {
160        // Early layers (0-10) are more important than middle layers
161        let layer_factor = if layer_index < 10 {
162            0.9
163        } else if layer_index < 20 {
164            0.7
165        } else {
166            0.5
167        };
168        layer_factor * self.config.layer_weight
169    }
170
171    fn score_head_position(&self, head_index: Option<i32>) -> f32 {
172        if let Some(head) = head_index {
173            // First few heads often capture important patterns
174            let head_factor = if head < 4 { 0.8 } else { 0.5 };
175            head_factor * self.config.head_weight
176        } else {
177            0.0
178        }
179    }
180
181    fn score_value_size(&self, size_bytes: usize) -> f32 {
182        // Larger values might be more important (more context)
183        let size_factor = (size_bytes as f32).min(10000.0) / 10000.0;
184        size_factor * self.config.value_size_weight
185    }
186
187    fn score_key_engagement(&self, key_hash: &str) -> f32 {
188        self.key_engagement.get(key_hash).map_or(0.0, |&e| e * 0.3)
189    }
190
191    pub fn update_engagement(&mut self, key_hash: &str, was_retrieved: bool) {
192        let engagement_increase = if was_retrieved { 0.15 } else { 0.05 };
193        
194        let current = self.key_engagement.entry(key_hash.to_string()).or_insert(0.3);
195        *current = (*current + engagement_increase)
196            .min(self.config.max_engagement)
197            .max(self.config.min_engagement);
198        
199        // Decay other keys
200        self.decay_other_keys(key_hash);
201    }
202
203    fn decay_other_keys(&mut self, current_key: &str) {
204        for (key, engagement) in self.key_engagement.iter_mut() {
205            if *key != current_key {
206                *engagement = (*engagement * self.config.engagement_decay)
207                    .max(self.config.min_engagement);
208            }
209        }
210    }
211
212    /// Determine if an entry should be preserved during cache clearing
213    pub fn should_preserve_entry(
214        &self,
215        importance_score: f32,
216        key_type: &str,
217        layer_index: i32,
218        config_threshold: f32,
219    ) -> bool {
220        let base_preservation = match key_type {
221            "attention_key" | "attention_value" => 0.8,
222            "ffn_key" | "ffn_value" => 0.6,
223            _ => 0.5,
224        };
225        
226        let layer_factor = if layer_index < 8 { 1.2 } else { 1.0 };
227        let combined_score = importance_score * layer_factor;
228        
229        combined_score >= config_threshold || base_preservation >= 0.7
230    }
231
232    /// Extract keywords from a key for retrieval
233    pub fn extract_keywords(&self, key_data: Option<&[u8]>) -> Vec<String> {
234        let mut keywords = Vec::new();
235        
236        if let Some(data) = key_data {
237            if let Ok(key_str) = std::str::from_utf8(data) {
238                // Simple keyword extraction
239                let words: Vec<&str> = key_str.split_whitespace().collect();
240                for word in words.iter().filter(|w| w.len() > 3) {
241                    let word_lower = word.to_lowercase();
242                    
243                    // Check if it's a meaningful word
244                    if !self.is_stop_word(&word_lower) {
245                        keywords.push(word_lower);
246                    }
247                }
248            }
249        }
250        
251        keywords.dedup();
252        keywords.truncate(5); // Limit to top 5 keywords
253        keywords
254    }
255    
256    fn is_stop_word(&self, word: &str) -> bool {
257        let stop_words = [
258            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
259            "of", "with", "by", "is", "am", "are", "was", "were", "be", "been",
260            "being", "have", "has", "had", "do", "does", "did", "will", "would",
261            "shall", "should", "may", "might", "must", "can", "could", "this",
262            "that", "these", "those", "it", "its", "it's",
263        ];
264        stop_words.contains(&word)
265    }
266}
267
268/// Score the importance of a conversation message based on its role and content.
269///
270/// Returns a value in [0.1, 0.95]. Used when persisting messages to SQLite so that
271/// retrieval strategies (ImportanceFiltered, hybrid ranking) have meaningful scores
272/// to work with rather than a flat 0.5 for everything.
273pub fn score_message_importance(role: &str, content: &str) -> f32 {
274    // Role base: system prompts are always high-value anchors; assistant responses
275    // carry more information density than user turns on average.
276    let role_base: f32 = match role {
277        "system" => 0.9,
278        "assistant" => 0.6,
279        _ => 0.4, // user (and any other role)
280    };
281
282    // Content bonus: scan for signals that indicate a high-information message.
283    let mut content_bonus: f32 = 0.0;
284
285    // Code blocks are almost always important context to preserve.
286    if content.contains("```") {
287        content_bonus += 0.2;
288    }
289
290    // Pattern-based signals using the same regexes the KV cache scorer uses for keys.
291    for (pattern_name, regex) in KEY_PATTERNS.iter() {
292        if regex.is_match(content) {
293            content_bonus += match *pattern_name {
294                "important_concept" => 0.15,
295                "code_related" => 0.10,
296                "system_prompt" => 0.10,
297                "question" => 0.05,
298                "numeric" => 0.04,
299                _ => 0.02,
300            };
301        }
302    }
303    // Cap so a single very rich message doesn't dominate everything else.
304    let content_bonus = content_bonus.min(0.35);
305
306    // Length bonus: longer messages carry more information (log-like saturation at ~3000 chars).
307    let length_bonus = ((content.len() as f32) / 3000.0).min(0.1);
308
309    (role_base + content_bonus + length_bonus).clamp(0.1, 0.95)
310}
311
/// Implementation of the trait required by the cache_extractor module
impl crate::cache_management::cache_extractor::CacheEntryScorer for CacheEntryScorer {
    fn extract_keywords(&self, key_data: Option<&[u8]>) -> Vec<String> {
        // Reuse the logic defined in the inherent impl. Inherent methods take
        // precedence over trait methods during method resolution, so this call
        // dispatches to the inherent `extract_keywords` above rather than
        // recursing into this trait method.
        self.extract_keywords(key_data)
    }
}