// offline_intelligence/context_engine/smart_retrieval.rs
use crate::memory::Message;
use crate::memory_db::{StoredMessage, Summary as DbSummary};
use crate::context_engine::tier_manager::TierManager;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{info, debug};

/// Tunables for the smart retrieval pipeline.
#[derive(Debug, Clone)]
pub struct SmartRetrievalConfig {
    /// Hard token budget for the assembled context (current + retrieved history).
    pub max_retrieved_tokens: usize,

    /// Prefer compressed Tier 2 summaries over raw historical messages.
    pub prefer_summaries: bool,

    /// Minimum importance score a stored message needs to be retrieved.
    pub importance_threshold: f32,

    /// Keep contiguous runs of messages together when chunking.
    /// NOTE(review): not read anywhere in this module — presumably consumed
    /// by a caller; confirm before relying on it.
    pub chunk_contiguous_messages: bool,

    /// Use the hierarchical builder (cross-session → summaries → details);
    /// otherwise the standard builder (cross-session → details) is used.
    pub use_hierarchical_context: bool,

    /// Master switch; when false, retrieval falls back to the current messages.
    pub enabled: bool,
}
40
41impl Default for SmartRetrievalConfig {
42 fn default() -> Self {
43 Self {
44 max_retrieved_tokens: 1000, prefer_summaries: true, importance_threshold: 0.5, chunk_contiguous_messages: true,
48 use_hierarchical_context: true,
49 enabled: true,
50 }
51 }
52}
53
54#[derive(Debug, Clone)]
56pub struct RetrievalResult {
57 pub strategy: RetrievalStrategy,
59
60 pub messages: Vec<Message>,
62
63 pub compute_savings: f32,
65
66 pub retrieved_tokens: usize,
68
69 pub sessions_referenced: Vec<String>,
71}
72
/// How the context for a request was obtained, roughly cheapest first.
#[derive(Debug, Clone, PartialEq)]
pub enum RetrievalStrategy {
    /// Served verbatim from the Tier 1 hot cache.
    HotCacheHit,
    /// Built primarily from compressed Tier 2 summaries.
    SummaryCompression,
    /// Built from messages filtered by importance score.
    ImportanceFiltered,
    /// No optimization applied; full history used.
    FullRetrieval,
    /// No historical content was available to retrieve.
    NoRetrieval,
}
91
92pub struct SmartRetrieval {
94 tier_manager: Arc<RwLock<TierManager>>,
95 config: SmartRetrievalConfig,
96}
97
98impl SmartRetrieval {
99 pub fn new(tier_manager: Arc<RwLock<TierManager>>, config: SmartRetrievalConfig) -> Self {
101 Self {
102 tier_manager,
103 config,
104 }
105 }
106
107 pub async fn retrieve(
109 &self,
110 session_id: &str,
111 current_messages: &[Message],
112 tier2_summaries: Option<Vec<DbSummary>>,
113 tier3_messages: Option<Vec<StoredMessage>>,
114 cross_session_messages: Option<Vec<StoredMessage>>,
115 ) -> anyhow::Result<RetrievalResult> {
116 if !self.config.enabled {
117 debug!("Smart retrieval disabled, using fallback");
118 return self.fallback_retrieval(current_messages);
119 }
120
121 let tier_manager = self.tier_manager.read().await;
123 if let Some(hot_messages) = tier_manager.get_tier1_content(session_id).await {
124 let retrieved_tokens = self.count_tokens(&hot_messages);
125 info!("🚀 Smart retrieval: Tier 1 hot cache hit for session {}", session_id);
126 return Ok(RetrievalResult {
127 strategy: RetrievalStrategy::HotCacheHit,
128 messages: hot_messages,
129 compute_savings: 1.0, retrieved_tokens,
131 sessions_referenced: vec![session_id.to_string()],
132 });
133 }
134 drop(tier_manager);
135
136 let has_summaries = tier2_summaries.as_ref().map(|s| !s.is_empty()).unwrap_or(false);
138 let has_tier3 = tier3_messages.as_ref().map(|m| !m.is_empty()).unwrap_or(false);
139 let has_cross_session = cross_session_messages.as_ref().map(|m| !m.is_empty()).unwrap_or(false);
140
141 if !has_summaries && !has_tier3 && !has_cross_session {
142 debug!("No historical content available, returning current messages");
143 return Ok(RetrievalResult {
144 strategy: RetrievalStrategy::NoRetrieval,
145 messages: current_messages.to_vec(),
146 compute_savings: 0.0,
147 retrieved_tokens: 0,
148 sessions_referenced: vec![],
149 });
150 }
151
152 let optimized_context = if self.config.use_hierarchical_context {
154 self.build_hierarchical_context(
155 current_messages,
156 tier2_summaries.as_ref(),
157 tier3_messages.as_ref(),
158 cross_session_messages.as_ref(),
159 ).await?
160 } else {
161 self.build_standard_context(
162 current_messages,
163 tier3_messages.as_ref(),
164 cross_session_messages.as_ref(),
165 ).await?
166 };
167
168 let strategy = if has_summaries && self.config.prefer_summaries {
170 RetrievalStrategy::SummaryCompression
171 } else if self.config.importance_threshold > 0.0 {
172 RetrievalStrategy::ImportanceFiltered
173 } else {
174 RetrievalStrategy::FullRetrieval
175 };
176
177 let compute_savings = self.estimate_compute_savings(&strategy, &optimized_context.messages);
179
180 info!(
181 "Smart retrieval complete: Strategy={:?}, Tokens={}, Savings={:.1}%",
182 strategy,
183 optimized_context.retrieved_tokens,
184 compute_savings * 100.0
185 );
186
187 Ok(optimized_context)
188 }
189
190 async fn build_hierarchical_context(
192 &self,
193 current_messages: &[Message],
194 tier2_summaries: Option<&Vec<DbSummary>>,
195 tier3_messages: Option<&Vec<StoredMessage>>,
196 cross_session_messages: Option<&Vec<StoredMessage>>,
197 ) -> anyhow::Result<RetrievalResult> {
198 let mut context = Vec::new();
199 let mut retrieved_tokens = 0;
200 let mut sessions_referenced = Vec::new();
201
202 let current_tokens: usize = current_messages.iter()
204 .map(|m| self.estimate_message_tokens(m))
205 .sum();
206
207 let budget_for_history = self.config.max_retrieved_tokens.saturating_sub(current_tokens);
208
209 if let Some(cross_msgs) = cross_session_messages {
211 if !cross_msgs.is_empty() {
212 let cross_context = self.add_cross_session_context(
213 cross_msgs,
214 budget_for_history / 3, );
216 retrieved_tokens += self.count_tokens(&cross_context);
217
218 for msg in cross_msgs.iter().take(3) {
220 if !sessions_referenced.contains(&msg.session_id) {
221 sessions_referenced.push(msg.session_id.clone());
222 }
223 }
224
225 context.extend(cross_context);
226 }
227 }
228
229 if self.config.prefer_summaries {
231 if let Some(summaries) = tier2_summaries {
232 if !summaries.is_empty() {
233 let summary_context = self.add_summary_context(
234 summaries,
235 budget_for_history.saturating_sub(retrieved_tokens),
236 );
237 retrieved_tokens += self.count_tokens(&summary_context);
238 context.extend(summary_context);
239
240 info!("📋 Used {} summaries (compressed context)", summaries.len());
241 }
242 }
243 }
244
245 if retrieved_tokens < budget_for_history {
247 if let Some(tier3_msgs) = tier3_messages {
248 let remaining_budget = budget_for_history.saturating_sub(retrieved_tokens);
249 let detail_context = self.add_important_details(
250 tier3_msgs,
251 remaining_budget,
252 );
253 retrieved_tokens += self.count_tokens(&detail_context);
254 context.extend(detail_context);
255 }
256 }
257
258 context.extend_from_slice(current_messages);
260
261 Ok(RetrievalResult {
262 strategy: RetrievalStrategy::SummaryCompression,
263 messages: context,
264 compute_savings: 0.0, retrieved_tokens,
266 sessions_referenced,
267 })
268 }
269
270 async fn build_standard_context(
272 &self,
273 current_messages: &[Message],
274 tier3_messages: Option<&Vec<StoredMessage>>,
275 cross_session_messages: Option<&Vec<StoredMessage>>,
276 ) -> anyhow::Result<RetrievalResult> {
277 let mut context = Vec::new();
278 let mut retrieved_tokens = 0;
279 let mut sessions_referenced = Vec::new();
280
281 let current_tokens: usize = current_messages.iter()
283 .map(|m| self.estimate_message_tokens(m))
284 .sum();
285
286 let budget_for_history = self.config.max_retrieved_tokens.saturating_sub(current_tokens);
287
288 if let Some(cross_msgs) = cross_session_messages {
290 if !cross_msgs.is_empty() {
291 let cross_context = self.add_cross_session_context(cross_msgs, budget_for_history / 2);
292 retrieved_tokens += self.count_tokens(&cross_context);
293
294 for msg in cross_msgs.iter().take(3) {
295 if !sessions_referenced.contains(&msg.session_id) {
296 sessions_referenced.push(msg.session_id.clone());
297 }
298 }
299
300 context.extend(cross_context);
301 }
302 }
303
304 if let Some(tier3_msgs) = tier3_messages {
306 let remaining_budget = budget_for_history.saturating_sub(retrieved_tokens);
307 let detail_context = self.add_important_details(tier3_msgs, remaining_budget);
308 retrieved_tokens += self.count_tokens(&detail_context);
309 context.extend(detail_context);
310 }
311
312 context.extend_from_slice(current_messages);
314
315 Ok(RetrievalResult {
316 strategy: RetrievalStrategy::ImportanceFiltered,
317 messages: context,
318 compute_savings: 0.0,
319 retrieved_tokens,
320 sessions_referenced,
321 })
322 }
323
324 fn add_cross_session_context(
326 &self,
327 cross_messages: &[StoredMessage],
328 token_budget: usize,
329 ) -> Vec<Message> {
330 let mut context = Vec::new();
331 let mut used_tokens = 0;
332
333 context.push(Message {
335 role: "system".to_string(),
336 content: "[Context from previous conversations]".to_string(),
337 });
338 used_tokens += 8;
339
340 let mut scored: Vec<_> = cross_messages.iter()
342 .map(|m| (m, m.importance_score))
343 .collect();
344 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
345
346 for (msg, _score) in scored.iter().take(3) {
347 let msg_tokens = msg.tokens as usize;
348 if used_tokens + msg_tokens > token_budget {
349 break;
350 }
351
352 context.push(Message {
353 role: msg.role.clone(),
354 content: format!("[From earlier: {}]", msg.content),
355 });
356 used_tokens += msg_tokens;
357 }
358
359 debug!("Added {} cross-session messages ({} tokens)", context.len() - 1, used_tokens);
360 context
361 }
362
363 fn add_summary_context(
365 &self,
366 summaries: &[DbSummary],
367 token_budget: usize,
368 ) -> Vec<Message> {
369 let mut context = Vec::new();
370 let mut used_tokens = 0;
371
372 let mut scored: Vec<_> = summaries.iter()
374 .map(|s| (s, s.compression_ratio))
375 .collect();
376 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
377
378 let best_compression = scored.first().map(|(s, _)| s.compression_ratio).unwrap_or(1.0);
380
381 for (summary, _score) in scored.iter() {
382 let summary_tokens = summary.summary_text.len() / 4;
383 if used_tokens + summary_tokens > token_budget {
384 break;
385 }
386
387 context.push(Message {
388 role: "system".to_string(),
389 content: format!("[Summary: {}]", summary.summary_text),
390 });
391 used_tokens += summary_tokens;
392 }
393
394 info!("Added {} summaries ({} tokens, compression saved {}%)",
395 context.len(),
396 used_tokens,
397 (1.0 - best_compression) * 100.0
398 );
399
400 context
401 }
402
403 fn add_important_details(
405 &self,
406 messages: &[StoredMessage],
407 token_budget: usize,
408 ) -> Vec<Message> {
409 let mut context = Vec::new();
410 let mut used_tokens = 0;
411
412 let important: Vec<_> = messages.iter()
414 .filter(|m| m.importance_score >= self.config.importance_threshold)
415 .collect();
416
417 if important.is_empty() {
418 debug!("No messages meet importance threshold {}", self.config.importance_threshold);
419 return context;
420 }
421
422 let mut scored = important.clone();
424 scored.sort_by(|a, b| b.importance_score.partial_cmp(&a.importance_score).unwrap_or(std::cmp::Ordering::Equal));
425
426 for msg in scored {
428 let msg_tokens = msg.tokens as usize;
429 if used_tokens + msg_tokens > token_budget {
430 break;
431 }
432
433 context.push(Message {
434 role: msg.role.clone(),
435 content: msg.content.clone(),
436 });
437 used_tokens += msg_tokens;
438 }
439
440 info!("Added {} important messages ({} tokens, threshold={:.2})",
441 context.len(),
442 used_tokens,
443 self.config.importance_threshold
444 );
445
446 context
447 }
448
449 fn estimate_compute_savings(&self, strategy: &RetrievalStrategy, messages: &[Message]) -> f32 {
451 match strategy {
452 RetrievalStrategy::HotCacheHit => 1.0, RetrievalStrategy::SummaryCompression => {
454 let total_tokens = self.count_tokens(messages);
457 if total_tokens < 100 {
458 0.95 } else if total_tokens < 500 {
460 0.75 } else {
462 0.5 }
464 }
465 RetrievalStrategy::ImportanceFiltered => {
466 0.6 }
469 RetrievalStrategy::FullRetrieval => 0.0, RetrievalStrategy::NoRetrieval => 0.0, }
472 }
473
474 fn count_tokens(&self, messages: &[Message]) -> usize {
476 messages.iter()
477 .map(|m| self.estimate_message_tokens(m))
478 .sum()
479 }
480
481 fn estimate_message_tokens(&self, message: &Message) -> usize {
483 message.content.len() / 4
484 }
485
486 fn fallback_retrieval(&self, current_messages: &[Message]) -> anyhow::Result<RetrievalResult> {
488 Ok(RetrievalResult {
489 strategy: RetrievalStrategy::FullRetrieval,
490 messages: current_messages.to_vec(),
491 compute_savings: 0.0,
492 retrieved_tokens: 0,
493 sessions_referenced: vec![],
494 })
495 }
496}
497
498impl Clone for SmartRetrieval {
499 fn clone(&self) -> Self {
500 Self {
501 tier_manager: Arc::clone(&self.tier_manager),
502 config: self.config.clone(),
503 }
504 }
505}