// offline_intelligence/context_engine/smart_retrieval.rs
//! Smart retrieval with hierarchical context optimization
//!
//! This module implements intelligent retrieval that minimizes recomputation cost
//! by preferring summaries over raw messages and enforcing strict token budgets.
//!
//! Key optimizations:
//! - Tier 1 (hot cache) → O(1) return
//! - Tier 2 (summaries) → Compressed context (50 tokens vs 2000)
//! - Tier 3 (database) → Importance-filtered, token-budgeted
//! - Hierarchical assembly → Summary first, details on-demand
12use crate::memory::Message;
13use crate::memory_db::{StoredMessage, Summary as DbSummary};
14use crate::context_engine::tier_manager::TierManager;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17use tracing::{info, debug};
18
/// Tuning knobs for the smart retrieval pipeline.
#[derive(Debug, Clone)]
pub struct SmartRetrievalConfig {
    /// Token budget for retrieved historical context (excludes current messages)
    pub max_retrieved_tokens: usize,

    /// When `true`, compressed summaries are used instead of raw messages
    /// whenever they are available
    pub prefer_summaries: bool,

    /// Messages scoring below this importance value (0.0-1.0) are skipped
    pub importance_threshold: f32,

    /// Group contiguous messages into chunks for better llama.cpp caching
    pub chunk_contiguous_messages: bool,

    /// Assemble context hierarchically: summary first, then details
    pub use_hierarchical_context: bool,

    /// Master switch; `false` falls back to the original retrieval behavior
    pub enabled: bool,
}
40
41impl Default for SmartRetrievalConfig {
42    fn default() -> Self {
43        Self {
44            max_retrieved_tokens: 1000,  // Budget for old context
45            prefer_summaries: true,       // Always prefer compressed context
46            importance_threshold: 0.5,    // Only retrieve important messages
47            chunk_contiguous_messages: true,
48            use_hierarchical_context: true,
49            enabled: true,
50        }
51    }
52}
53
54/// Result of smart retrieval operation
55#[derive(Debug, Clone)]
56pub struct RetrievalResult {
57    /// Strategy used for retrieval
58    pub strategy: RetrievalStrategy,
59
60    /// Optimized messages to send to LLM
61    pub messages: Vec<Message>,
62
63    /// Estimated computation cost saved (0.0-1.0)
64    pub compute_savings: f32,
65
66    /// Token count of retrieved context
67    pub retrieved_tokens: usize,
68
69    /// Sessions referenced in the retrieval
70    pub sessions_referenced: Vec<String>,
71}
72
/// The retrieval path taken for a request.
#[derive(Debug, Clone, PartialEq)]
pub enum RetrievalStrategy {
    /// Current session was already resident in the Tier 1 hot cache
    HotCacheHit,

    /// Old context was compressed via stored summaries
    SummaryCompression,

    /// Chunks retrieved after importance-score filtering
    ImportanceFiltered,

    /// Unoptimized fallback: everything retrieved as-is
    FullRetrieval,

    /// Nothing retrieved (context was already fresh)
    NoRetrieval,
}
91
92/// Smart retrieval orchestrator
93pub struct SmartRetrieval {
94    tier_manager: Arc<RwLock<TierManager>>,
95    config: SmartRetrievalConfig,
96}
97
98impl SmartRetrieval {
99    /// Create a new smart retrieval instance
100    pub fn new(tier_manager: Arc<RwLock<TierManager>>, config: SmartRetrievalConfig) -> Self {
101        Self {
102            tier_manager,
103            config,
104        }
105    }
106
107    /// Main retrieval function with smart optimization
108    pub async fn retrieve(
109        &self,
110        session_id: &str,
111        current_messages: &[Message],
112        tier2_summaries: Option<Vec<DbSummary>>,
113        tier3_messages: Option<Vec<StoredMessage>>,
114        cross_session_messages: Option<Vec<StoredMessage>>,
115    ) -> anyhow::Result<RetrievalResult> {
116        if !self.config.enabled {
117            debug!("Smart retrieval disabled, using fallback");
118            return self.fallback_retrieval(current_messages);
119        }
120
121        // Step 1: Check Tier 1 hot cache
122        let tier_manager = self.tier_manager.read().await;
123        if let Some(hot_messages) = tier_manager.get_tier1_content(session_id).await {
124            let retrieved_tokens = self.count_tokens(&hot_messages);
125            info!("🚀 Smart retrieval: Tier 1 hot cache hit for session {}", session_id);
126            return Ok(RetrievalResult {
127                strategy: RetrievalStrategy::HotCacheHit,
128                messages: hot_messages,
129                compute_savings: 1.0,  // 100% savings - no recomputation
130                retrieved_tokens,
131                sessions_referenced: vec![session_id.to_string()],
132            });
133        }
134        drop(tier_manager);
135
136        // Step 2: Check if we have any historical content
137        let has_summaries = tier2_summaries.as_ref().map(|s| !s.is_empty()).unwrap_or(false);
138        let has_tier3 = tier3_messages.as_ref().map(|m| !m.is_empty()).unwrap_or(false);
139        let has_cross_session = cross_session_messages.as_ref().map(|m| !m.is_empty()).unwrap_or(false);
140
141        if !has_summaries && !has_tier3 && !has_cross_session {
142            debug!("No historical content available, returning current messages");
143            return Ok(RetrievalResult {
144                strategy: RetrievalStrategy::NoRetrieval,
145                messages: current_messages.to_vec(),
146                compute_savings: 0.0,
147                retrieved_tokens: 0,
148                sessions_referenced: vec![],
149            });
150        }
151
152        // Step 3: Build optimized context based on available tiers
153        let optimized_context = if self.config.use_hierarchical_context {
154            self.build_hierarchical_context(
155                current_messages,
156                tier2_summaries.as_ref(),
157                tier3_messages.as_ref(),
158                cross_session_messages.as_ref(),
159            ).await?
160        } else {
161            self.build_standard_context(
162                current_messages,
163                tier3_messages.as_ref(),
164                cross_session_messages.as_ref(),
165            ).await?
166        };
167
168        // Determine strategy based on what was used
169        let strategy = if has_summaries && self.config.prefer_summaries {
170            RetrievalStrategy::SummaryCompression
171        } else if self.config.importance_threshold > 0.0 {
172            RetrievalStrategy::ImportanceFiltered
173        } else {
174            RetrievalStrategy::FullRetrieval
175        };
176
177        // Calculate compute savings
178        let compute_savings = self.estimate_compute_savings(&strategy, &optimized_context.messages);
179
180        info!(
181            "Smart retrieval complete: Strategy={:?}, Tokens={}, Savings={:.1}%",
182            strategy,
183            optimized_context.retrieved_tokens,
184            compute_savings * 100.0
185        );
186
187        Ok(optimized_context)
188    }
189
190    /// Build hierarchical context (summary first, then details)
191    async fn build_hierarchical_context(
192        &self,
193        current_messages: &[Message],
194        tier2_summaries: Option<&Vec<DbSummary>>,
195        tier3_messages: Option<&Vec<StoredMessage>>,
196        cross_session_messages: Option<&Vec<StoredMessage>>,
197    ) -> anyhow::Result<RetrievalResult> {
198        let mut context = Vec::new();
199        let mut retrieved_tokens = 0;
200        let mut sessions_referenced = Vec::new();
201
202        // Reserve budget for current messages
203        let current_tokens: usize = current_messages.iter()
204            .map(|m| self.estimate_message_tokens(m))
205            .sum();
206
207        let budget_for_history = self.config.max_retrieved_tokens.saturating_sub(current_tokens);
208
209        // Step 1: Add cross-session context if available (highest priority)
210        if let Some(cross_msgs) = cross_session_messages {
211            if !cross_msgs.is_empty() {
212                let cross_context = self.add_cross_session_context(
213                    cross_msgs,
214                    budget_for_history / 3,  // Allocate 1/3 budget for cross-session
215                );
216                retrieved_tokens += self.count_tokens(&cross_context);
217
218                // Track unique sessions
219                for msg in cross_msgs.iter().take(3) {
220                    if !sessions_referenced.contains(&msg.session_id) {
221                        sessions_referenced.push(msg.session_id.clone());
222                    }
223                }
224
225                context.extend(cross_context);
226            }
227        }
228
229        // Step 2: Add summaries if available and preferred
230        if self.config.prefer_summaries {
231            if let Some(summaries) = tier2_summaries {
232                if !summaries.is_empty() {
233                    let summary_context = self.add_summary_context(
234                        summaries,
235                        budget_for_history.saturating_sub(retrieved_tokens),
236                    );
237                    retrieved_tokens += self.count_tokens(&summary_context);
238                    context.extend(summary_context);
239
240                    info!("📋 Used {} summaries (compressed context)", summaries.len());
241                }
242            }
243        }
244
245        // Step 3: Add important details from Tier 3 if budget allows
246        if retrieved_tokens < budget_for_history {
247            if let Some(tier3_msgs) = tier3_messages {
248                let remaining_budget = budget_for_history.saturating_sub(retrieved_tokens);
249                let detail_context = self.add_important_details(
250                    tier3_msgs,
251                    remaining_budget,
252                );
253                retrieved_tokens += self.count_tokens(&detail_context);
254                context.extend(detail_context);
255            }
256        }
257
258        // Step 4: Add current messages (always included, full detail)
259        context.extend_from_slice(current_messages);
260
261        Ok(RetrievalResult {
262            strategy: RetrievalStrategy::SummaryCompression,
263            messages: context,
264            compute_savings: 0.0,  // Will be calculated by caller
265            retrieved_tokens,
266            sessions_referenced,
267        })
268    }
269
270    /// Build standard context (importance-filtered only)
271    async fn build_standard_context(
272        &self,
273        current_messages: &[Message],
274        tier3_messages: Option<&Vec<StoredMessage>>,
275        cross_session_messages: Option<&Vec<StoredMessage>>,
276    ) -> anyhow::Result<RetrievalResult> {
277        let mut context = Vec::new();
278        let mut retrieved_tokens = 0;
279        let mut sessions_referenced = Vec::new();
280
281        // Calculate budget
282        let current_tokens: usize = current_messages.iter()
283            .map(|m| self.estimate_message_tokens(m))
284            .sum();
285
286        let budget_for_history = self.config.max_retrieved_tokens.saturating_sub(current_tokens);
287
288        // Add cross-session messages
289        if let Some(cross_msgs) = cross_session_messages {
290            if !cross_msgs.is_empty() {
291                let cross_context = self.add_cross_session_context(cross_msgs, budget_for_history / 2);
292                retrieved_tokens += self.count_tokens(&cross_context);
293
294                for msg in cross_msgs.iter().take(3) {
295                    if !sessions_referenced.contains(&msg.session_id) {
296                        sessions_referenced.push(msg.session_id.clone());
297                    }
298                }
299
300                context.extend(cross_context);
301            }
302        }
303
304        // Add filtered tier3 messages
305        if let Some(tier3_msgs) = tier3_messages {
306            let remaining_budget = budget_for_history.saturating_sub(retrieved_tokens);
307            let detail_context = self.add_important_details(tier3_msgs, remaining_budget);
308            retrieved_tokens += self.count_tokens(&detail_context);
309            context.extend(detail_context);
310        }
311
312        // Add current messages
313        context.extend_from_slice(current_messages);
314
315        Ok(RetrievalResult {
316            strategy: RetrievalStrategy::ImportanceFiltered,
317            messages: context,
318            compute_savings: 0.0,
319            retrieved_tokens,
320            sessions_referenced,
321        })
322    }
323
324    /// Add cross-session context with budget enforcement
325    fn add_cross_session_context(
326        &self,
327        cross_messages: &[StoredMessage],
328        token_budget: usize,
329    ) -> Vec<Message> {
330        let mut context = Vec::new();
331        let mut used_tokens = 0;
332
333        // Add bridge message
334        context.push(Message {
335            role: "system".to_string(),
336            content: "[Context from previous conversations]".to_string(),
337        });
338        used_tokens += 8;
339
340        // Add top 3 most important cross-session messages
341        let mut scored: Vec<_> = cross_messages.iter()
342            .map(|m| (m, m.importance_score))
343            .collect();
344        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
345
346        for (msg, _score) in scored.iter().take(3) {
347            let msg_tokens = msg.tokens as usize;
348            if used_tokens + msg_tokens > token_budget {
349                break;
350            }
351
352            context.push(Message {
353                role: msg.role.clone(),
354                content: format!("[From earlier: {}]", msg.content),
355            });
356            used_tokens += msg_tokens;
357        }
358
359        debug!("Added {} cross-session messages ({} tokens)", context.len() - 1, used_tokens);
360        context
361    }
362
363    /// Add summary context with compression
364    fn add_summary_context(
365        &self,
366        summaries: &[DbSummary],
367        token_budget: usize,
368    ) -> Vec<Message> {
369        let mut context = Vec::new();
370        let mut used_tokens = 0;
371
372        // Sort summaries by relevance (compression ratio as proxy for importance)
373        let mut scored: Vec<_> = summaries.iter()
374            .map(|s| (s, s.compression_ratio))
375            .collect();
376        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
377
378        // Track best compression for logging
379        let best_compression = scored.first().map(|(s, _)| s.compression_ratio).unwrap_or(1.0);
380
381        for (summary, _score) in scored.iter() {
382            let summary_tokens = summary.summary_text.len() / 4;
383            if used_tokens + summary_tokens > token_budget {
384                break;
385            }
386
387            context.push(Message {
388                role: "system".to_string(),
389                content: format!("[Summary: {}]", summary.summary_text),
390            });
391            used_tokens += summary_tokens;
392        }
393
394        info!("Added {} summaries ({} tokens, compression saved {}%)",
395              context.len(),
396              used_tokens,
397              (1.0 - best_compression) * 100.0
398        );
399
400        context
401    }
402
403    /// Add important details with importance filtering and budget
404    fn add_important_details(
405        &self,
406        messages: &[StoredMessage],
407        token_budget: usize,
408    ) -> Vec<Message> {
409        let mut context = Vec::new();
410        let mut used_tokens = 0;
411
412        // Filter by importance threshold
413        let important: Vec<_> = messages.iter()
414            .filter(|m| m.importance_score >= self.config.importance_threshold)
415            .collect();
416
417        if important.is_empty() {
418            debug!("No messages meet importance threshold {}", self.config.importance_threshold);
419            return context;
420        }
421
422        // Sort by importance score descending
423        let mut scored = important.clone();
424        scored.sort_by(|a, b| b.importance_score.partial_cmp(&a.importance_score).unwrap_or(std::cmp::Ordering::Equal));
425
426        // Add messages until budget exhausted
427        for msg in scored {
428            let msg_tokens = msg.tokens as usize;
429            if used_tokens + msg_tokens > token_budget {
430                break;
431            }
432
433            context.push(Message {
434                role: msg.role.clone(),
435                content: msg.content.clone(),
436            });
437            used_tokens += msg_tokens;
438        }
439
440        info!("Added {} important messages ({} tokens, threshold={:.2})",
441              context.len(),
442              used_tokens,
443              self.config.importance_threshold
444        );
445
446        context
447    }
448
449    /// Estimate compute savings based on strategy
450    fn estimate_compute_savings(&self, strategy: &RetrievalStrategy, messages: &[Message]) -> f32 {
451        match strategy {
452            RetrievalStrategy::HotCacheHit => 1.0,  // 100% savings (cached in RAM)
453            RetrievalStrategy::SummaryCompression => {
454                // Estimate based on token reduction
455                // If we compressed 2000 tokens to 50, that's 97.5% savings
456                let total_tokens = self.count_tokens(messages);
457                if total_tokens < 100 {
458                    0.95  // High compression
459                } else if total_tokens < 500 {
460                    0.75  // Medium compression
461                } else {
462                    0.5   // Low compression
463                }
464            }
465            RetrievalStrategy::ImportanceFiltered => {
466                // Savings from filtering out unimportant messages
467                0.6  // ~60% savings from importance filtering
468            }
469            RetrievalStrategy::FullRetrieval => 0.0,  // No savings
470            RetrievalStrategy::NoRetrieval => 0.0,    // No savings (but also no cost)
471        }
472    }
473
474    /// Count total tokens in messages
475    fn count_tokens(&self, messages: &[Message]) -> usize {
476        messages.iter()
477            .map(|m| self.estimate_message_tokens(m))
478            .sum()
479    }
480
481    /// Estimate tokens for a message (rough approximation: 4 chars per token)
482    fn estimate_message_tokens(&self, message: &Message) -> usize {
483        message.content.len() / 4
484    }
485
486    /// Fallback retrieval (disabled smart retrieval)
487    fn fallback_retrieval(&self, current_messages: &[Message]) -> anyhow::Result<RetrievalResult> {
488        Ok(RetrievalResult {
489            strategy: RetrievalStrategy::FullRetrieval,
490            messages: current_messages.to_vec(),
491            compute_savings: 0.0,
492            retrieved_tokens: 0,
493            sessions_referenced: vec![],
494        })
495    }
496}
497
498impl Clone for SmartRetrieval {
499    fn clone(&self) -> Self {
500        Self {
501            tier_manager: Arc::clone(&self.tier_manager),
502            config: self.config.clone(),
503        }
504    }
505}