Skip to main content

offline_intelligence/context_engine/
retrieval_planner.rs

1use crate::memory::Message;
2use crate::memory_db::MemoryDatabase;
3use std::sync::Arc;
4use tracing::{debug, info};
5
/// Describes how context should be pulled from memory for the next turn.
///
/// Built by `RetrievalPlanner::create_plan`. The [`Default`] value describes a
/// "no retrieval, current context only" plan.
#[derive(Debug, Clone)]
pub struct RetrievalPlan {
    /// `true` when anything beyond the current context should be fetched.
    pub needs_retrieval: bool,

    /// Tier 1: the current KV cache (in-context messages).
    pub use_tier1: bool,
    /// Tier 2: summarized content.
    pub use_tier2: bool,
    /// Tier 3: the full message database.
    pub use_tier3: bool,

    /// Search other sessions in addition to the current one.
    pub cross_session_search: bool,

    /// Employ semantic search.
    pub semantic_search: bool,
    /// Employ keyword search.
    pub keyword_search: bool,
    /// Employ time-based search.
    pub temporal_search: bool,

    /// Upper bound on the number of messages to retrieve.
    pub max_messages: usize,
    /// Token budget for the plan.
    pub max_tokens: usize,

    /// Target compression ratio to use when summarizing.
    pub target_compression: f32,

    /// Specific topics to search for.
    pub search_topics: Vec<String>,
}

impl Default for RetrievalPlan {
    /// Current-context-only plan: Tier 1 enabled, every other tier and search
    /// strategy disabled, 100 messages / 4000 tokens, 0.3 compression and no
    /// search topics.
    fn default() -> Self {
        Self {
            search_topics: Vec::new(),
            target_compression: 0.3,
            max_tokens: 4000,
            max_messages: 100,
            temporal_search: false,
            keyword_search: false,
            semantic_search: false,
            cross_session_search: false,
            use_tier3: false,
            use_tier2: false,
            use_tier1: true,
            needs_retrieval: false,
        }
    }
}

55/// Plans retrieval strategies based on conversation context
56pub struct RetrievalPlanner {
57    database: Arc<MemoryDatabase>,
58    recent_threshold_messages: usize,
59    max_retrieval_time_ms: u64,
60}
61
62impl RetrievalPlanner {
63    /// Create a new retrieval planner
64    pub fn new(database: Arc<MemoryDatabase>) -> Self {
65        Self {
66            database,
67            recent_threshold_messages: 20,
68            max_retrieval_time_ms: 200,
69        }
70    }
71    
72    /// Analyze conversation and create retrieval plan
73    pub async fn create_plan(
74        &self,
75        session_id: &str,
76        current_messages: &[Message],
77        max_context_tokens: usize,
78        user_query: Option<&str>,
79        has_past_refs: bool, // NEW parameter
80    ) -> anyhow::Result<RetrievalPlan> {
81        let mut plan = RetrievalPlan {
82            max_tokens: max_context_tokens,
83            ..Default::default()
84        };
85        
86        // Check user query first for past references
87        let mut has_past_references_in_query = false;
88        if let Some(query) = user_query {
89            // Check for cross-session references
90            if self.is_cross_session_query(query, session_id) {
91                plan.needs_retrieval = true;
92                plan.cross_session_search = true;
93                plan.search_topics = self.extract_topics_from_query(query);
94            }
95            
96            // Check for past references in the CURRENT query
97            has_past_references_in_query = self.has_past_references_in_text(query);
98        }
99
100        // Also use the passed has_past_refs parameter if available
101        if !has_past_references_in_query && has_past_refs {
102            has_past_references_in_query = true;
103        }
104
105        // Check if we need retrieval based on context window limits
106        if !plan.needs_retrieval && !self.needs_retrieval(current_messages, max_context_tokens) {
107            // Even if within limits, check if query asks for past content
108            if has_past_references_in_query {
109                plan.needs_retrieval = true;
110                debug!("Retrieval needed: query asks for past content");
111            } else {
112                debug!("No retrieval needed - within context limits and no past references");
113                return Ok(plan);
114            }
115        }
116        
117        plan.needs_retrieval = true;
118        
119        // Always use current context (Tier 1)
120        plan.use_tier1 = true;
121        
122        // Analyze conversation to determine retrieval strategy
123        let analysis = self.analyze_conversation(current_messages, user_query).await?;
124        
125        // Determine which tiers to use based on analysis
126        self.plan_tier_usage(&mut plan, &analysis, session_id, has_past_references_in_query).await?;
127        
128        // Determine search strategies
129        self.plan_search_strategies(&mut plan, &analysis, user_query);
130        
131        // Extract search topics from analysis if not already set by cross-session logic
132        if plan.search_topics.is_empty() {
133            plan.search_topics = analysis.extracted_topics;
134        }
135        
136        // Adjust limits based on available tokens
137        self.adjust_limits(&mut plan, current_messages, max_context_tokens);
138        
139        info!(
140            "Created retrieval plan: Tiers({}{}{}), CrossSession({}), Search({}{}{}), PastRefs={}",
141            if plan.use_tier1 { "1" } else { "" },
142            if plan.use_tier2 { "2" } else { "" },
143            if plan.use_tier3 { "3" } else { "" },
144            plan.cross_session_search,
145            if plan.semantic_search { "S" } else { "" },
146            if plan.keyword_search { "K" } else { "" },
147            if plan.temporal_search { "T" } else { "" },
148            has_past_references_in_query
149        );
150        
151        Ok(plan)
152    }
153    
154    /// Check if retrieval is needed based on message volume
155    fn needs_retrieval(&self, messages: &[Message], max_tokens: usize) -> bool {
156        if messages.len() <= 1 {
157            return false;
158        }
159        
160        // Estimate tokens
161        let estimated_tokens: usize = messages.iter()
162            .map(|m| m.content.len() / 4)
163            .sum();
164        
165        estimated_tokens > max_tokens
166    }
167
168    /// Detect if the user query is asking for information from other sessions
169    fn is_cross_session_query(&self, query: &str, _current_session_id: &str) -> bool {
170        let cross_session_patterns = [
171            "previously", "before", "earlier", "last time", "yesterday",
172            "do you remember", "we discussed", "we talked about",
173            "what did we talk", "remember when", "recall",
174        ];
175        
176        let query_lower = query.to_lowercase();
177        
178        // Check for explicit cross-session references
179        cross_session_patterns.iter().any(|pattern| query_lower.contains(pattern))
180    }
181    
182    /// Check for past references in ANY text (not just current messages)
183    pub fn has_past_references_in_text(&self, text: &str) -> bool {  // CHANGED: made public
184        let reference_patterns = [
185            "earlier", "before", "previous", "last time", "yesterday",
186            "we discussed", "we talked about", "remember", "recall",
187            "did we talk", "have we discussed", "what did we say",
188            "what was said", "mentioned earlier", "previously mentioned",
189        ];
190        
191        let text_lower = text.to_lowercase();
192        reference_patterns.iter().any(|p| text_lower.contains(p))
193    }
194
195    /// Helper to extract topics directly from a single query string
196    fn extract_topics_from_query(&self, query: &str) -> Vec<String> {
197        let words: Vec<&str> = query.split_whitespace().collect();
198        if words.len() < 3 {
199            return vec![query.to_string()];
200        }
201        
202        // Simple extraction logic: take the last few words as the topic
203        let topic = words.iter()
204            .rev()
205            .take(4)
206            .rev()
207            .copied()
208            .collect::<Vec<&str>>()
209            .join(" ");
210            
211        vec![topic]
212    }
213    
214    /// Analyze conversation context
215    async fn analyze_conversation(
216        &self,
217        messages: &[Message],
218        user_query: Option<&str>,
219    ) -> anyhow::Result<ConversationAnalysis> {
220        let mut analysis = ConversationAnalysis {
221            extracted_topics: self.extract_topics(messages),
222            has_past_references: self.has_past_references_in_messages(messages),
223            ..Default::default()
224        };
225        
226        // Check if query asks for specific information
227        if let Some(query) = user_query {
228            analysis.requires_specific_details = self.requires_specific_details(query);
229            analysis.query_complexity = self.assess_query_complexity(query);
230        }
231        
232        // Analyze conversation length and patterns
233        analysis.conversation_length = messages.len();
234        analysis.recency_pattern = self.analyze_recency_pattern(messages);
235        
236        Ok(analysis)
237    }
238    
239    /// Plan which memory tiers to use
240    async fn plan_tier_usage(
241        &self,
242        plan: &mut RetrievalPlan,
243        analysis: &ConversationAnalysis,
244        session_id: &str,
245        has_past_references_in_query: bool,
246    ) -> anyhow::Result<()> {
247        let has_summaries = self.database.summaries
248            .get_session_summaries(session_id)
249            .map(|summaries| !summaries.is_empty())
250            .unwrap_or_else(|e| {
251                debug!("Database error checking summaries: {}", e);
252                false
253            });
254        
255        plan.use_tier2 = has_summaries;
256        
257        // NEW: Check if we have messages in database for this session
258        let has_db_messages = self.check_if_session_has_db_messages(session_id).await?;
259        
260        // TIER 3 LOGIC FIXED:
261        // 1. Always use Tier 3 if query asks for past content (regardless of conversation length)
262        if has_past_references_in_query && has_db_messages {
263            plan.use_tier3 = true;
264            debug!("Query asks for past content, using Tier 3 (database)");
265        }
266        
267        // 2. Use Tier 3 for specific details
268        if analysis.requires_specific_details && has_db_messages {
269            plan.use_tier3 = true;
270            debug!("Specific details requested, using Tier 3");
271        }
272        
273        // 3. Use Tier 3 for cross-session search
274        if plan.cross_session_search {
275            plan.use_tier3 = true;
276            debug!("Cross-session search, using Tier 3");
277        }
278        
279        // 4. Use Tier 3 for long conversations (for summarization)
280        if analysis.conversation_length > 30 && has_db_messages && !plan.use_tier3 {
281            plan.use_tier3 = true;
282            debug!("Long conversation ({} messages), using Tier 3", analysis.conversation_length);
283        }
284        
285        // 5. Use Tier 3 if we have past references in recent messages
286        if analysis.has_past_references && has_db_messages && !plan.use_tier3 {
287            plan.use_tier3 = true;
288            debug!("Past references in messages, using Tier 3");
289        }
290        
291        if analysis.conversation_length > 100 {
292            plan.target_compression = 0.2;
293        }
294        
295        Ok(())
296    }
297    
298    /// Check if session has messages in database
299    async fn check_if_session_has_db_messages(&self, session_id: &str) -> anyhow::Result<bool> {
300        // Quick check: get just 1 message to see if session exists in DB
301        match self.database.conversations.get_session_messages(session_id, Some(1), Some(0)) {
302            Ok(messages) => Ok(!messages.is_empty()),
303            Err(e) => {
304                debug!("Error checking DB for session {}: {}", session_id, e);
305                Ok(false)
306            }
307        }
308    }
309    
310    /// Plan search strategies
311    fn plan_search_strategies(
312        &self,
313        plan: &mut RetrievalPlan,
314        analysis: &ConversationAnalysis,
315        user_query: Option<&str>,
316    ) {
317        // Semantic search for complex queries or when topics are unclear
318        plan.semantic_search = analysis.query_complexity > 0.5 || (analysis.extracted_topics.is_empty() && !plan.cross_session_search);
319        
320        // Keyword search for specific references or cross-session topic matches
321        plan.keyword_search = analysis.requires_specific_details 
322            || analysis.has_past_references 
323            || plan.cross_session_search
324            || !analysis.extracted_topics.is_empty();
325        
326        // Temporal search for time-based references
327        plan.temporal_search = self.has_temporal_references(user_query.unwrap_or(""));
328    }
329    
330    /// Adjust limits based on available context
331    fn adjust_limits(
332        &self,
333        plan: &mut RetrievalPlan,
334        current_messages: &[Message],
335        max_context_tokens: usize,
336    ) {
337        let current_tokens: usize = current_messages.iter()
338            .map(|m| m.content.len() / 4)
339            .sum();
340        
341        let available_for_retrieval = max_context_tokens.saturating_sub(current_tokens);
342        
343        // Assume ~50 tokens per message on average
344        let estimated_messages = available_for_retrieval / 50;
345        plan.max_messages = estimated_messages.clamp(10, 100);
346    }
347    
348    /// Extract topics from messages
349    fn extract_topics(&self, messages: &[Message]) -> Vec<String> {
350        let mut topics = Vec::new();
351        
352        for message in messages.iter().rev().filter(|m| m.role == "user").take(3) {
353            let words: Vec<&str> = message.content.split_whitespace().collect();
354            
355            for i in 0..words.len().saturating_sub(2) {
356                let word_lower = words[i].to_lowercase();
357                
358                if word_lower == "about" || word_lower == "regarding" {
359                    let topic = words[i + 1..].iter()
360                        .take(3)
361                        .copied()
362                        .collect::<Vec<&str>>()
363                        .join(" ");
364                    
365                    if !topic.is_empty() {
366                        topics.push(topic);
367                    }
368                }
369                
370                if ["what", "how", "why", "when", "where", "who"].contains(&word_lower.as_str()) {
371                    let topic = words[i + 1..].iter()
372                        .take(4)
373                        .copied()
374                        .collect::<Vec<&str>>()
375                        .join(" ");
376                    
377                    if !topic.is_empty() {
378                        topics.push(topic);
379                    }
380                }
381            }
382        }
383        
384        topics.dedup();
385        topics.truncate(3);
386        
387        topics
388    }
389    
390    /// Check for references to past content in messages (renamed for clarity)
391    fn has_past_references_in_messages(&self, messages: &[Message]) -> bool {
392        let reference_patterns = [
393            "earlier", "before", "previous", "last time", "yesterday",
394            "we discussed", "we talked about", "remember", "recall",
395        ];
396        
397        for message in messages.iter().rev().take(5) {
398            let content_lower = message.content.to_lowercase();
399            if reference_patterns.iter().any(|p| content_lower.contains(p)) {
400                return true;
401            }
402        }
403        
404        false
405    }
406    
407    /// Check if query requires specific details
408    fn requires_specific_details(&self, query: &str) -> bool {
409        let detail_patterns = [
410            "exactly", "specifically", "in detail", "step by step",
411            "the code", "the number", "the date", "the name",
412            "show me", "give me", "tell me",
413        ];
414        
415        let query_lower = query.to_lowercase();
416        detail_patterns.iter().any(|p| query_lower.contains(p))
417    }
418    
419    /// Assess query complexity
420    fn assess_query_complexity(&self, query: &str) -> f32 {
421        let words: Vec<&str> = query.split_whitespace().collect();
422        
423        if words.len() < 3 {
424            return 0.2;
425        }
426        
427        let mut complexity = 0.0;
428        complexity += (words.len() as f32).min(50.0) / 100.0;
429        
430        let clause_count = query.split(&[',', ';', '&']).count();
431        complexity += (clause_count as f32).min(5.0) / 10.0;
432        
433        let technical_terms = ["code", "function", "algorithm", "parameter", "variable"];
434        for term in technical_terms {
435            if query.to_lowercase().contains(term) {
436                complexity += 0.2;
437            }
438        }
439        
440        complexity.min(1.0)
441    }
442    
443    /// Analyze recency pattern
444    fn analyze_recency_pattern(&self, messages: &[Message]) -> RecencyPattern {
445        if messages.len() < 5 {
446            return RecencyPattern::RecentOnly;
447        }
448        
449        let recent_topics = self.extract_topics(&messages[messages.len().saturating_sub(5)..]);
450        let older_topics = self.extract_topics(&messages[..messages.len().saturating_sub(5)]);
451        
452        let overlap = recent_topics.iter()
453            .filter(|topic| older_topics.contains(topic))
454            .count();
455        
456        match overlap {
457            0 => RecencyPattern::TopicJumping,
458            1 => RecencyPattern::Mixed,
459            _ => RecencyPattern::TopicContinuation,
460        }
461    }
462    
463    /// Check for temporal references
464    fn has_temporal_references(&self, query: &str) -> bool {
465        let temporal_patterns = [
466            "yesterday", "today", "tomorrow", "last week", "last month",
467            "earlier", "before", "previously", "in the past",
468        ];
469        
470        let query_lower = query.to_lowercase();
471        temporal_patterns.iter().any(|p| query_lower.contains(p))
472    }
473}
474
/// Signals extracted from the conversation and the current query, consumed
/// by `RetrievalPlanner` when choosing tiers and search strategies.
#[derive(Default, Debug)]
struct ConversationAnalysis {
    /// Topics mined from recent user messages.
    extracted_topics: Vec<String>,
    /// Whether recent messages refer back to earlier content.
    has_past_references: bool,
    /// Whether the query asks for specific details.
    requires_specific_details: bool,
    /// Heuristic complexity score of the query.
    query_complexity: f32,
    /// Number of messages analyzed.
    conversation_length: usize,
    /// How recent topics relate to older ones.
    recency_pattern: RecencyPattern,
}

/// How the most recent messages relate, topic-wise, to the earlier ones.
#[derive(Default, PartialEq, Clone, Debug)]
enum RecencyPattern {
    /// Too few messages to compare; also the default.
    #[default]
    RecentOnly,
    /// Two or more topics shared with earlier messages.
    TopicContinuation,
    /// No topics shared with earlier messages.
    TopicJumping,
    /// Exactly one topic shared with earlier messages.
    Mixed,
}

496impl Clone for RetrievalPlanner {
497    fn clone(&self) -> Self {
498        Self {
499            database: self.database.clone(),
500            recent_threshold_messages: self.recent_threshold_messages,
501            max_retrieval_time_ms: self.max_retrieval_time_ms,
502        }
503    }
504}