offline_intelligence/context_engine/
smart_retrieval.rs1use crate::memory::Message;
11use crate::memory_db::StoredMessage;
12use crate::context_engine::tier_manager::TierManager;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15use tracing::{info, debug};
16
/// Tunables for [`SmartRetrieval`].
#[derive(Debug, Clone)]
pub struct SmartRetrievalConfig {
    /// Estimated-token budget shared by the current messages and any retrieved
    /// history; history only receives what the current messages leave over.
    pub max_retrieved_tokens: usize,

    /// Minimum `importance_score` a tier 3 message needs to be considered for
    /// inclusion. A value above `0.0` also marks the retrieval as importance-filtered.
    pub importance_threshold: f32,

    /// NOTE(review): not read by the retrieval code visible in this file; presumably
    /// meant to keep adjacent messages together when retrieving — confirm.
    pub chunk_contiguous_messages: bool,

    /// When `false`, retrieval is bypassed and the current messages pass through as-is.
    pub enabled: bool,
}
32
33impl Default for SmartRetrievalConfig {
34 fn default() -> Self {
35 Self {
36 max_retrieved_tokens: 1000,
37 importance_threshold: 0.5,
38 chunk_contiguous_messages: true,
39 enabled: true,
40 }
41 }
42}
43
44impl SmartRetrievalConfig {
45 pub fn from_ctx_size(ctx_size: u32) -> Self {
49 Self {
50 max_retrieved_tokens: (ctx_size as f32 * 0.25) as usize,
51 ..Self::default()
52 }
53 }
54}
55
56#[derive(Debug, Clone)]
58pub struct RetrievalResult {
59 pub strategy: RetrievalStrategy,
61
62 pub messages: Vec<Message>,
64
65 pub compute_savings: f32,
67
68 pub retrieved_tokens: usize,
70
71 pub sessions_referenced: Vec<String>,
73}
74
/// The path [`SmartRetrieval`] took to build a context.
#[derive(Debug, Clone, PartialEq)]
pub enum RetrievalStrategy {
    /// Tier 1 ("hot") content for the session was returned directly.
    HotCacheHit,

    /// Historical messages were filtered by importance before being added.
    ImportanceFiltered,

    /// No importance-based reduction was applied (for example, smart retrieval is
    /// disabled and the current messages pass through untouched).
    FullRetrieval,

    /// No historical content was available; only the current messages are used.
    NoRetrieval,
}
90
91pub struct SmartRetrieval {
93 tier_manager: Arc<RwLock<TierManager>>,
94 config: SmartRetrievalConfig,
95}
96
97impl SmartRetrieval {
98 pub fn new(tier_manager: Arc<RwLock<TierManager>>, config: SmartRetrievalConfig) -> Self {
100 Self {
101 tier_manager,
102 config,
103 }
104 }
105
106 pub async fn retrieve(
108 &self,
109 session_id: &str,
110 current_messages: &[Message],
111 tier3_messages: Option<Vec<StoredMessage>>,
112 cross_session_messages: Option<Vec<StoredMessage>>,
113 ) -> anyhow::Result<RetrievalResult> {
114 if !self.config.enabled {
115 debug!("Smart retrieval disabled, using fallback");
116 return self.fallback_retrieval(current_messages);
117 }
118
119 let tier_manager = self.tier_manager.read().await;
121 if let Some(hot_messages) = tier_manager.get_tier1_content(session_id).await {
122 let retrieved_tokens = self.count_tokens(&hot_messages);
123 info!("🚀 Smart retrieval: Tier 1 hot cache hit for session {}", session_id);
124 return Ok(RetrievalResult {
125 strategy: RetrievalStrategy::HotCacheHit,
126 messages: hot_messages,
127 compute_savings: 1.0,
128 retrieved_tokens,
129 sessions_referenced: vec![session_id.to_string()],
130 });
131 }
132 drop(tier_manager);
133
134 let has_tier3 = tier3_messages.as_ref().map(|m| !m.is_empty()).unwrap_or(false);
136 let has_cross_session = cross_session_messages.as_ref().map(|m| !m.is_empty()).unwrap_or(false);
137
138 if !has_tier3 && !has_cross_session {
139 debug!("No historical content available, returning current messages");
140 return Ok(RetrievalResult {
141 strategy: RetrievalStrategy::NoRetrieval,
142 messages: current_messages.to_vec(),
143 compute_savings: 0.0,
144 retrieved_tokens: 0,
145 sessions_referenced: vec![],
146 });
147 }
148
149 let optimized_context = self.build_context_from_tiers(
151 current_messages,
152 tier3_messages.as_ref(),
153 cross_session_messages.as_ref(),
154 ).await?;
155
156 let strategy = if self.config.importance_threshold > 0.0 {
157 RetrievalStrategy::ImportanceFiltered
158 } else {
159 RetrievalStrategy::FullRetrieval
160 };
161
162 let compute_savings = self.estimate_compute_savings(&strategy, &optimized_context.messages);
163
164 info!(
165 "Smart retrieval complete: Strategy={:?}, Tokens={}, Savings={:.1}%",
166 strategy,
167 optimized_context.retrieved_tokens,
168 compute_savings * 100.0
169 );
170
171 Ok(optimized_context)
172 }
173
174 async fn build_context_from_tiers(
176 &self,
177 current_messages: &[Message],
178 tier3_messages: Option<&Vec<StoredMessage>>,
179 cross_session_messages: Option<&Vec<StoredMessage>>,
180 ) -> anyhow::Result<RetrievalResult> {
181 let mut context = Vec::new();
182 let mut retrieved_tokens = 0;
183 let mut sessions_referenced = Vec::new();
184
185 let current_tokens: usize = current_messages.iter()
186 .map(|m| self.estimate_message_tokens(m))
187 .sum();
188
189 let budget_for_history = self.config.max_retrieved_tokens.saturating_sub(current_tokens);
190
191 if let Some(cross_msgs) = cross_session_messages {
193 if !cross_msgs.is_empty() {
194 let cross_context = self.add_cross_session_context(cross_msgs, budget_for_history / 3);
195 retrieved_tokens += self.count_tokens(&cross_context);
196
197 for msg in cross_msgs.iter().take(3) {
198 if !sessions_referenced.contains(&msg.session_id) {
199 sessions_referenced.push(msg.session_id.clone());
200 }
201 }
202
203 context.extend(cross_context);
204 }
205 }
206
207 if let Some(tier3_msgs) = tier3_messages {
209 let remaining_budget = budget_for_history.saturating_sub(retrieved_tokens);
210 let detail_context = self.add_important_details(tier3_msgs, remaining_budget);
211 retrieved_tokens += self.count_tokens(&detail_context);
212 context.extend(detail_context);
213 }
214
215 context.extend_from_slice(current_messages);
217
218 Ok(RetrievalResult {
219 strategy: RetrievalStrategy::ImportanceFiltered,
220 messages: context,
221 compute_savings: 0.0,
222 retrieved_tokens,
223 sessions_referenced,
224 })
225 }
226
227 fn add_cross_session_context(
229 &self,
230 cross_messages: &[StoredMessage],
231 token_budget: usize,
232 ) -> Vec<Message> {
233 let mut context = Vec::new();
234 let mut used_tokens = 0;
235
236 context.push(Message {
238 role: "system".to_string(),
239 content: "[Context from previous conversations]".to_string(),
240 });
241 used_tokens += 8;
242
243 let mut scored: Vec<_> = cross_messages.iter()
245 .map(|m| (m, m.importance_score))
246 .collect();
247 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
248
249 for (msg, _score) in scored.iter().take(3) {
250 let msg_tokens = msg.tokens as usize;
251 if used_tokens + msg_tokens > token_budget {
252 break;
253 }
254
255 context.push(Message {
256 role: msg.role.clone(),
257 content: format!("[From earlier: {}]", msg.content),
258 });
259 used_tokens += msg_tokens;
260 }
261
262 debug!("Added {} cross-session messages ({} tokens)", context.len() - 1, used_tokens);
263 context
264 }
265
266 fn add_important_details(
268 &self,
269 messages: &[StoredMessage],
270 token_budget: usize,
271 ) -> Vec<Message> {
272 let mut context = Vec::new();
273 let mut used_tokens = 0;
274
275 let important: Vec<_> = messages.iter()
277 .filter(|m| m.importance_score >= self.config.importance_threshold)
278 .collect();
279
280 if important.is_empty() {
281 debug!("No messages meet importance threshold {}", self.config.importance_threshold);
282 return context;
283 }
284
285 let mut scored = important.clone();
287 scored.sort_by(|a, b| b.importance_score.partial_cmp(&a.importance_score).unwrap_or(std::cmp::Ordering::Equal));
288
289 for msg in scored {
291 let msg_tokens = msg.tokens as usize;
292 if used_tokens + msg_tokens > token_budget {
293 break;
294 }
295
296 context.push(Message {
297 role: msg.role.clone(),
298 content: msg.content.clone(),
299 });
300 used_tokens += msg_tokens;
301 }
302
303 info!("Added {} important messages ({} tokens, threshold={:.2})",
304 context.len(),
305 used_tokens,
306 self.config.importance_threshold
307 );
308
309 context
310 }
311
312 fn estimate_compute_savings(&self, strategy: &RetrievalStrategy, _messages: &[Message]) -> f32 {
314 match strategy {
315 RetrievalStrategy::HotCacheHit => 1.0,
316 RetrievalStrategy::ImportanceFiltered => 0.6,
317 RetrievalStrategy::FullRetrieval => 0.0,
318 RetrievalStrategy::NoRetrieval => 0.0,
319 }
320 }
321
322 fn count_tokens(&self, messages: &[Message]) -> usize {
324 messages.iter()
325 .map(|m| self.estimate_message_tokens(m))
326 .sum()
327 }
328
329 fn estimate_message_tokens(&self, message: &Message) -> usize {
331 message.content.len() / 4
332 }
333
334 fn fallback_retrieval(&self, current_messages: &[Message]) -> anyhow::Result<RetrievalResult> {
336 Ok(RetrievalResult {
337 strategy: RetrievalStrategy::FullRetrieval,
338 messages: current_messages.to_vec(),
339 compute_savings: 0.0,
340 retrieved_tokens: 0,
341 sessions_referenced: vec![],
342 })
343 }
344}
345
346impl Clone for SmartRetrieval {
347 fn clone(&self) -> Self {
348 Self {
349 tier_manager: Arc::clone(&self.tier_manager),
350 config: self.config.clone(),
351 }
352 }
353}