offline_intelligence/cache_management/
cache_scorer.rs1use regex::Regex;
4use std::collections::HashMap;
5use lazy_static::lazy_static;
6
7lazy_static! {
8 static ref KEY_PATTERNS: HashMap<&'static str, Regex> = {
9 let mut m = HashMap::new();
10
11 m.insert(
13 "system_prompt",
14 Regex::new(r"system|instruction|prompt|assistant_role").unwrap(),
15 );
16
17 m.insert(
19 "code_related",
20 Regex::new(r"def |function |class |import |return |print |code|program|algorithm|python|rust|javascript|java|c\+\+|sql|```").unwrap(),
21 );
22
23 m.insert(
25 "important_concept",
26 Regex::new(r"important|critical|crucial|essential|must|need|require|urgent|asap|priority|key|main|primary").unwrap(),
27 );
28
29 m.insert(
31 "question",
32 Regex::new(r"what|how|why|when|where|who|explain|describe|can you|could you|would you|should").unwrap(),
33 );
34
35 m.insert(
37 "numeric",
38 Regex::new(r"\d+|date|time|age|year|month|day|hour|minute|second").unwrap(),
39 );
40
41 m
42 };
43}
44
/// Per-entry inputs to [`CacheEntryScorer::score_entry`].
///
/// Every field is borrowed or `Copy`, so building one does not allocate and
/// the whole struct can be copied freely.
#[derive(Debug, Clone, Copy)]
pub struct CacheEntryParams<'a> {
    /// Key used to look up the entry's engagement history.
    pub key_hash: &'a str,
    /// Raw key bytes; only scanned for text patterns when they are valid UTF-8.
    pub key_data: Option<&'a [u8]>,
    /// Entry kind, e.g. `"attention_key"`, `"attention_value"`, `"ffn_key"`, `"ffn_value"`.
    pub key_type: &'a str,
    /// Layer the entry belongs to.
    pub layer_index: i32,
    /// Head index, if the entry is tied to a specific head.
    pub head_index: Option<i32>,
    /// Number of recorded accesses to the entry.
    pub access_count: i32,
    /// Seconds elapsed since the entry was last accessed.
    pub last_accessed_seconds_ago: f32,
    /// Size of the cached value in bytes.
    pub value_size_bytes: usize,
}
56
57pub struct CacheEntryScorer {
59 key_engagement: HashMap<String, f32>, config: CacheScoringConfig,
61}
62
/// Tunable weights and engagement bounds for [`CacheEntryScorer`].
///
/// Each `*_weight` scales one component of an entry's score. The engagement
/// fields control how stored per-key engagement decays and which range it is
/// kept within.
#[derive(Clone, Debug)]
pub struct CacheScoringConfig {
    /// Scale for the recency component, which shrinks as time since last access grows.
    pub recency_weight: f32,
    /// Scale for the access-count component.
    pub access_count_weight: f32,
    /// Scale for the key-type / key-text pattern component.
    pub key_pattern_weight: f32,
    /// Scale for the layer-position component.
    pub layer_weight: f32,
    /// Scale for the head-position component.
    pub head_weight: f32,
    /// Scale for the value-size component.
    pub value_size_weight: f32,
    /// Factor multiplied into every *other* key's engagement whenever one key is updated.
    pub engagement_decay: f32,
    /// Lower bound for stored engagement.
    pub min_engagement: f32,
    /// Upper bound for stored engagement.
    pub max_engagement: f32,
}
75
76impl Default for CacheScoringConfig {
77 fn default() -> Self {
78 Self {
79 recency_weight: 0.3,
80 access_count_weight: 0.2,
81 key_pattern_weight: 0.25,
82 layer_weight: 0.1,
83 head_weight: 0.05,
84 value_size_weight: 0.1,
85 engagement_decay: 0.95,
86 min_engagement: 0.1,
87 max_engagement: 1.0,
88 }
89 }
90}
91
92impl CacheEntryScorer {
93 pub fn new(config: CacheScoringConfig) -> Self {
95 Self {
96 key_engagement: HashMap::new(),
97 config,
98 }
99 }
100
101 pub fn score_entry(&self, params: CacheEntryParams) -> f32 {
103 let mut score = 0.0;
104
105 score += self.score_recency(params.last_accessed_seconds_ago);
106 score += self.score_access_count(params.access_count);
107 score += self.score_key_patterns(params.key_data, params.key_type);
108 score += self.score_layer_position(params.layer_index);
109 score += self.score_head_position(params.head_index);
110 score += self.score_value_size(params.value_size_bytes);
111 score += self.score_key_engagement(params.key_hash);
112
113 score.clamp(0.0, 1.0)
114 }
115
116 fn score_recency(&self, seconds_ago: f32) -> f32 {
117 let recency_factor = 1.0 / (1.0 + seconds_ago / 3600.0); recency_factor * self.config.recency_weight
119 }
120
121 fn score_access_count(&self, access_count: i32) -> f32 {
122 let normalized = (access_count as f32).min(100.0) / 100.0;
123 normalized * self.config.access_count_weight
124 }
125
126 fn score_key_patterns(&self, key_data: Option<&[u8]>, key_type: &str) -> f32 {
127 let mut pattern_score: f32 = 0.0;
129
130 match key_type {
132 "attention_key" | "attention_value" => pattern_score += 0.1,
133 "ffn_key" | "ffn_value" => pattern_score += 0.05,
134 _ => {}
135 }
136
137 if let Some(data) = key_data {
139 if let Ok(key_str) = std::str::from_utf8(data) {
140 for (pattern_name, regex) in KEY_PATTERNS.iter() {
141 if regex.is_match(key_str) {
142 let weight = match *pattern_name {
143 "system_prompt" => 0.8,
144 "code_related" => 0.7,
145 "important_concept" => 0.9,
146 "question" => 0.6,
147 "numeric" => 0.5,
148 _ => 0.3,
149 };
150 pattern_score += weight;
151 }
152 }
153 }
154 }
155
156 pattern_score.min(1.0) * self.config.key_pattern_weight
157 }
158
159 fn score_layer_position(&self, layer_index: i32) -> f32 {
160 let layer_factor = if layer_index < 10 {
162 0.9
163 } else if layer_index < 20 {
164 0.7
165 } else {
166 0.5
167 };
168 layer_factor * self.config.layer_weight
169 }
170
171 fn score_head_position(&self, head_index: Option<i32>) -> f32 {
172 if let Some(head) = head_index {
173 let head_factor = if head < 4 { 0.8 } else { 0.5 };
175 head_factor * self.config.head_weight
176 } else {
177 0.0
178 }
179 }
180
181 fn score_value_size(&self, size_bytes: usize) -> f32 {
182 let size_factor = (size_bytes as f32).min(10000.0) / 10000.0;
184 size_factor * self.config.value_size_weight
185 }
186
187 fn score_key_engagement(&self, key_hash: &str) -> f32 {
188 self.key_engagement.get(key_hash).map_or(0.0, |&e| e * 0.3)
189 }
190
191 pub fn update_engagement(&mut self, key_hash: &str, was_retrieved: bool) {
192 let engagement_increase = if was_retrieved { 0.15 } else { 0.05 };
193
194 let current = self.key_engagement.entry(key_hash.to_string()).or_insert(0.3);
195 *current = (*current + engagement_increase)
196 .min(self.config.max_engagement)
197 .max(self.config.min_engagement);
198
199 self.decay_other_keys(key_hash);
201 }
202
203 fn decay_other_keys(&mut self, current_key: &str) {
204 for (key, engagement) in self.key_engagement.iter_mut() {
205 if *key != current_key {
206 *engagement = (*engagement * self.config.engagement_decay)
207 .max(self.config.min_engagement);
208 }
209 }
210 }
211
212 pub fn should_preserve_entry(
214 &self,
215 importance_score: f32,
216 key_type: &str,
217 layer_index: i32,
218 config_threshold: f32,
219 ) -> bool {
220 let base_preservation = match key_type {
221 "attention_key" | "attention_value" => 0.8,
222 "ffn_key" | "ffn_value" => 0.6,
223 _ => 0.5,
224 };
225
226 let layer_factor = if layer_index < 8 { 1.2 } else { 1.0 };
227 let combined_score = importance_score * layer_factor;
228
229 combined_score >= config_threshold || base_preservation >= 0.7
230 }
231
232 pub fn extract_keywords(&self, key_data: Option<&[u8]>) -> Vec<String> {
234 let mut keywords = Vec::new();
235
236 if let Some(data) = key_data {
237 if let Ok(key_str) = std::str::from_utf8(data) {
238 let words: Vec<&str> = key_str.split_whitespace().collect();
240 for word in words.iter().filter(|w| w.len() > 3) {
241 let word_lower = word.to_lowercase();
242
243 if !self.is_stop_word(&word_lower) {
245 keywords.push(word_lower);
246 }
247 }
248 }
249 }
250
251 keywords.dedup();
252 keywords.truncate(5); keywords
254 }
255
256 fn is_stop_word(&self, word: &str) -> bool {
257 let stop_words = [
258 "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
259 "of", "with", "by", "is", "am", "are", "was", "were", "be", "been",
260 "being", "have", "has", "had", "do", "does", "did", "will", "would",
261 "shall", "should", "may", "might", "must", "can", "could", "this",
262 "that", "these", "those", "it", "its", "it's",
263 ];
264 stop_words.contains(&word)
265 }
266}
267
268pub fn score_message_importance(role: &str, content: &str) -> f32 {
274 let role_base: f32 = match role {
277 "system" => 0.9,
278 "assistant" => 0.6,
279 _ => 0.4, };
281
282 let mut content_bonus: f32 = 0.0;
284
285 if content.contains("```") {
287 content_bonus += 0.2;
288 }
289
290 for (pattern_name, regex) in KEY_PATTERNS.iter() {
292 if regex.is_match(content) {
293 content_bonus += match *pattern_name {
294 "important_concept" => 0.15,
295 "code_related" => 0.10,
296 "system_prompt" => 0.10,
297 "question" => 0.05,
298 "numeric" => 0.04,
299 _ => 0.02,
300 };
301 }
302 }
303 let content_bonus = content_bonus.min(0.35);
305
306 let length_bonus = ((content.len() as f32) / 3000.0).min(0.1);
308
309 (role_base + content_bonus + length_bonus).clamp(0.1, 0.95)
310}
311
312impl crate::cache_management::cache_extractor::CacheEntryScorer for CacheEntryScorer {
314 fn extract_keywords(&self, key_data: Option<&[u8]>) -> Vec<String> {
315 self.extract_keywords(key_data)
317 }
318}