Skip to main content

offline_intelligence/context_engine/
retrieval_planner.rs

1use crate::memory::Message;
2use crate::memory_db::MemoryDatabase;
3use std::sync::Arc;
4use tracing::{debug, info};
5
/// A plan describing whether, where, and how to retrieve content from memory.
///
/// Produced by `RetrievalPlanner::create_plan`.
#[derive(Debug, Clone)]
pub struct RetrievalPlan {
    /// `true` when any memory retrieval should happen at all.
    pub needs_retrieval: bool,

    /// Use Tier 1: the hot in-memory cache.
    pub use_tier1: bool,
    /// Use Tier 3: the full SQLite database.
    pub use_tier3: bool,

    /// Search sessions other than the current one.
    pub cross_session_search: bool,

    /// Enable the semantic search strategy.
    pub semantic_search: bool,
    /// Enable the keyword search strategy.
    pub keyword_search: bool,
    /// Enable the time-based search strategy.
    pub temporal_search: bool,

    /// Upper bound on the number of messages to retrieve.
    pub max_messages: usize,
    /// Token budget for the retrieval.
    pub max_tokens: usize,

    /// Specific topics to search for (may be empty).
    pub search_topics: Vec<String>,
}

impl Default for RetrievalPlan {
    /// A "do nothing" plan: no retrieval, Tier 1 only, every search strategy
    /// off, no topics, with limits of 100 messages and 4000 tokens.
    fn default() -> Self {
        Self {
            needs_retrieval: false,
            cross_session_search: false,
            use_tier1: true,
            use_tier3: false,
            semantic_search: false,
            keyword_search: false,
            temporal_search: false,
            search_topics: vec![],
            max_messages: 100,
            max_tokens: 4000,
        }
    }
}
48
49/// Plans retrieval strategies based on conversation context
50pub struct RetrievalPlanner {
51    database: Arc<MemoryDatabase>,
52    recent_threshold_messages: usize,
53    max_retrieval_time_ms: u64,
54}
55
56impl RetrievalPlanner {
57    /// Create a new retrieval planner
58    pub fn new(database: Arc<MemoryDatabase>) -> Self {
59        Self {
60            database,
61            recent_threshold_messages: 20,
62            max_retrieval_time_ms: 200,
63        }
64    }
65    
66    /// Analyze conversation and create retrieval plan
67    pub async fn create_plan(
68        &self,
69        session_id: &str,
70        current_messages: &[Message],
71        max_context_tokens: usize,
72        user_query: Option<&str>,
73        has_past_refs: bool, // NEW parameter
74    ) -> anyhow::Result<RetrievalPlan> {
75        let mut plan = RetrievalPlan {
76            max_tokens: max_context_tokens,
77            ..Default::default()
78        };
79        
80        // Check user query first for past references
81        let mut has_past_references_in_query = false;
82        if let Some(query) = user_query {
83            // Check for cross-session references
84            if self.is_cross_session_query(query, session_id) {
85                plan.needs_retrieval = true;
86                plan.cross_session_search = true;
87                plan.search_topics = self.extract_topics_from_query(query);
88            }
89            
90            // Check for past references in the CURRENT query
91            has_past_references_in_query = self.has_past_references_in_text(query);
92        }
93
94        // Also use the passed has_past_refs parameter if available
95        if !has_past_references_in_query && has_past_refs {
96            has_past_references_in_query = true;
97        }
98
99        // Check if we need retrieval based on context window limits
100        if !plan.needs_retrieval && !self.needs_retrieval(current_messages, max_context_tokens) {
101            // Even if within limits, check if query asks for past content
102            if has_past_references_in_query {
103                plan.needs_retrieval = true;
104                debug!("Retrieval needed: query asks for past content");
105            } else {
106                debug!("No retrieval needed - within context limits and no past references");
107                return Ok(plan);
108            }
109        }
110        
111        plan.needs_retrieval = true;
112        
113        // Always use current context (Tier 1)
114        plan.use_tier1 = true;
115        
116        // Analyze conversation to determine retrieval strategy
117        let analysis = self.analyze_conversation(current_messages, user_query).await?;
118        
119        // Determine which tiers to use based on analysis
120        self.plan_tier_usage(&mut plan, &analysis, session_id, has_past_references_in_query).await?;
121        
122        // Determine search strategies
123        self.plan_search_strategies(&mut plan, &analysis, user_query);
124        
125        // Extract search topics from analysis if not already set by cross-session logic
126        if plan.search_topics.is_empty() {
127            plan.search_topics = analysis.extracted_topics;
128        }
129        
130        // Adjust limits based on available tokens
131        self.adjust_limits(&mut plan, current_messages, max_context_tokens);
132        
133        info!(
134            "Created retrieval plan: Tiers({}{}), CrossSession({}), Search({}{}{}), PastRefs={}",
135            if plan.use_tier1 { "1" } else { "" },
136            if plan.use_tier3 { "3" } else { "" },
137            plan.cross_session_search,
138            if plan.semantic_search { "S" } else { "" },
139            if plan.keyword_search { "K" } else { "" },
140            if plan.temporal_search { "T" } else { "" },
141            has_past_references_in_query
142        );
143        
144        Ok(plan)
145    }
146    
147    /// Check if retrieval is needed based on message volume
148    fn needs_retrieval(&self, messages: &[Message], max_tokens: usize) -> bool {
149        if messages.len() <= 1 {
150            return false;
151        }
152        
153        // Estimate tokens
154        let estimated_tokens: usize = messages.iter()
155            .map(|m| m.content.len() / 4)
156            .sum();
157        
158        estimated_tokens > max_tokens
159    }
160
161    /// Detect if the user query is asking for information from other sessions
162    fn is_cross_session_query(&self, query: &str, _current_session_id: &str) -> bool {
163        let cross_session_patterns = [
164            "previously", "before", "earlier", "last time", "yesterday",
165            "do you remember", "we discussed", "we talked about",
166            "what did we talk", "remember when", "recall",
167        ];
168        
169        let query_lower = query.to_lowercase();
170        
171        // Check for explicit cross-session references
172        cross_session_patterns.iter().any(|pattern| query_lower.contains(pattern))
173    }
174    
175    /// Check for past references in ANY text (not just current messages)
176    pub fn has_past_references_in_text(&self, text: &str) -> bool {  // CHANGED: made public
177        let reference_patterns = [
178            "earlier", "before", "previous", "last time", "yesterday",
179            "we discussed", "we talked about", "remember", "recall",
180            "did we talk", "have we discussed", "what did we say",
181            "what was said", "mentioned earlier", "previously mentioned",
182        ];
183        
184        let text_lower = text.to_lowercase();
185        reference_patterns.iter().any(|p| text_lower.contains(p))
186    }
187
188    /// Helper to extract topics directly from a single query string
189    fn extract_topics_from_query(&self, query: &str) -> Vec<String> {
190        let words: Vec<&str> = query.split_whitespace().collect();
191        if words.len() < 3 {
192            return vec![query.to_string()];
193        }
194        
195        // Simple extraction logic: take the last few words as the topic
196        let topic = words.iter()
197            .rev()
198            .take(4)
199            .rev()
200            .copied()
201            .collect::<Vec<&str>>()
202            .join(" ");
203            
204        vec![topic]
205    }
206    
207    /// Analyze conversation context
208    async fn analyze_conversation(
209        &self,
210        messages: &[Message],
211        user_query: Option<&str>,
212    ) -> anyhow::Result<ConversationAnalysis> {
213        let mut analysis = ConversationAnalysis {
214            extracted_topics: self.extract_topics(messages),
215            has_past_references: self.has_past_references_in_messages(messages),
216            ..Default::default()
217        };
218        
219        // Check if query asks for specific information
220        if let Some(query) = user_query {
221            analysis.requires_specific_details = self.requires_specific_details(query);
222            analysis.query_complexity = self.assess_query_complexity(query);
223        }
224        
225        // Analyze conversation length and patterns
226        analysis.conversation_length = messages.len();
227        analysis.recency_pattern = self.analyze_recency_pattern(messages);
228        
229        Ok(analysis)
230    }
231    
232    /// Plan which memory tiers to use
233    async fn plan_tier_usage(
234        &self,
235        plan: &mut RetrievalPlan,
236        analysis: &ConversationAnalysis,
237        session_id: &str,
238        has_past_references_in_query: bool,
239    ) -> anyhow::Result<()> {
240        let has_db_messages = self.check_if_session_has_db_messages(session_id).await?;
241        
242        // TIER 3 LOGIC FIXED:
243        // 1. Always use Tier 3 if query asks for past content (regardless of conversation length)
244        if has_past_references_in_query && has_db_messages {
245            plan.use_tier3 = true;
246            debug!("Query asks for past content, using Tier 3 (database)");
247        }
248        
249        // 2. Use Tier 3 for specific details
250        if analysis.requires_specific_details && has_db_messages {
251            plan.use_tier3 = true;
252            debug!("Specific details requested, using Tier 3");
253        }
254        
255        // 3. Use Tier 3 for cross-session search
256        if plan.cross_session_search {
257            plan.use_tier3 = true;
258            debug!("Cross-session search, using Tier 3");
259        }
260        
261        // 4. Use Tier 3 for long conversations (for summarization)
262        if analysis.conversation_length > 30 && has_db_messages && !plan.use_tier3 {
263            plan.use_tier3 = true;
264            debug!("Long conversation ({} messages), using Tier 3", analysis.conversation_length);
265        }
266        
267        // 5. Use Tier 3 if we have past references in recent messages
268        if analysis.has_past_references && has_db_messages && !plan.use_tier3 {
269            plan.use_tier3 = true;
270            debug!("Past references in messages, using Tier 3");
271        }
272        
273        Ok(())
274    }
275    
276    /// Check if session has messages in database
277    async fn check_if_session_has_db_messages(&self, session_id: &str) -> anyhow::Result<bool> {
278        // Quick check: get just 1 message to see if session exists in DB
279        match self.database.conversations.get_session_messages(session_id, Some(1), Some(0)) {
280            Ok(messages) => Ok(!messages.is_empty()),
281            Err(e) => {
282                debug!("Error checking DB for session {}: {}", session_id, e);
283                Ok(false)
284            }
285        }
286    }
287    
288    /// Plan search strategies
289    fn plan_search_strategies(
290        &self,
291        plan: &mut RetrievalPlan,
292        analysis: &ConversationAnalysis,
293        user_query: Option<&str>,
294    ) {
295        // Semantic search for complex queries or when topics are unclear
296        plan.semantic_search = analysis.query_complexity > 0.5 || (analysis.extracted_topics.is_empty() && !plan.cross_session_search);
297        
298        // Keyword search for specific references or cross-session topic matches
299        plan.keyword_search = analysis.requires_specific_details 
300            || analysis.has_past_references 
301            || plan.cross_session_search
302            || !analysis.extracted_topics.is_empty();
303        
304        // Temporal search for time-based references
305        plan.temporal_search = self.has_temporal_references(user_query.unwrap_or(""));
306    }
307    
308    /// Adjust limits based on available context
309    fn adjust_limits(
310        &self,
311        plan: &mut RetrievalPlan,
312        current_messages: &[Message],
313        max_context_tokens: usize,
314    ) {
315        let current_tokens: usize = current_messages.iter()
316            .map(|m| m.content.len() / 4)
317            .sum();
318        
319        let available_for_retrieval = max_context_tokens.saturating_sub(current_tokens);
320        
321        // Assume ~50 tokens per message on average
322        let estimated_messages = available_for_retrieval / 50;
323        plan.max_messages = estimated_messages.clamp(10, 100);
324    }
325    
326    /// Extract topics from messages
327    fn extract_topics(&self, messages: &[Message]) -> Vec<String> {
328        let mut topics = Vec::new();
329        
330        for message in messages.iter().rev().filter(|m| m.role == "user").take(3) {
331            let words: Vec<&str> = message.content.split_whitespace().collect();
332            
333            for i in 0..words.len().saturating_sub(2) {
334                let word_lower = words[i].to_lowercase();
335                
336                if word_lower == "about" || word_lower == "regarding" {
337                    let topic = words[i + 1..].iter()
338                        .take(3)
339                        .copied()
340                        .collect::<Vec<&str>>()
341                        .join(" ");
342                    
343                    if !topic.is_empty() {
344                        topics.push(topic);
345                    }
346                }
347                
348                if ["what", "how", "why", "when", "where", "who"].contains(&word_lower.as_str()) {
349                    let topic = words[i + 1..].iter()
350                        .take(4)
351                        .copied()
352                        .collect::<Vec<&str>>()
353                        .join(" ");
354                    
355                    if !topic.is_empty() {
356                        topics.push(topic);
357                    }
358                }
359            }
360        }
361        
362        topics.dedup();
363        topics.truncate(3);
364        
365        topics
366    }
367    
368    /// Check for references to past content in messages (renamed for clarity)
369    fn has_past_references_in_messages(&self, messages: &[Message]) -> bool {
370        let reference_patterns = [
371            "earlier", "before", "previous", "last time", "yesterday",
372            "we discussed", "we talked about", "remember", "recall",
373        ];
374        
375        for message in messages.iter().rev().take(5) {
376            let content_lower = message.content.to_lowercase();
377            if reference_patterns.iter().any(|p| content_lower.contains(p)) {
378                return true;
379            }
380        }
381        
382        false
383    }
384    
385    /// Check if query requires specific details
386    fn requires_specific_details(&self, query: &str) -> bool {
387        let detail_patterns = [
388            "exactly", "specifically", "in detail", "step by step",
389            "the code", "the number", "the date", "the name",
390            "show me", "give me", "tell me",
391        ];
392        
393        let query_lower = query.to_lowercase();
394        detail_patterns.iter().any(|p| query_lower.contains(p))
395    }
396    
397    /// Assess query complexity
398    fn assess_query_complexity(&self, query: &str) -> f32 {
399        let words: Vec<&str> = query.split_whitespace().collect();
400        
401        if words.len() < 3 {
402            return 0.2;
403        }
404        
405        let mut complexity = 0.0;
406        complexity += (words.len() as f32).min(50.0) / 100.0;
407        
408        let clause_count = query.split(&[',', ';', '&']).count();
409        complexity += (clause_count as f32).min(5.0) / 10.0;
410        
411        let technical_terms = ["code", "function", "algorithm", "parameter", "variable"];
412        for term in technical_terms {
413            if query.to_lowercase().contains(term) {
414                complexity += 0.2;
415            }
416        }
417        
418        complexity.min(1.0)
419    }
420    
421    /// Analyze recency pattern
422    fn analyze_recency_pattern(&self, messages: &[Message]) -> RecencyPattern {
423        if messages.len() < 5 {
424            return RecencyPattern::RecentOnly;
425        }
426        
427        let recent_topics = self.extract_topics(&messages[messages.len().saturating_sub(5)..]);
428        let older_topics = self.extract_topics(&messages[..messages.len().saturating_sub(5)]);
429        
430        let overlap = recent_topics.iter()
431            .filter(|topic| older_topics.contains(topic))
432            .count();
433        
434        match overlap {
435            0 => RecencyPattern::TopicJumping,
436            1 => RecencyPattern::Mixed,
437            _ => RecencyPattern::TopicContinuation,
438        }
439    }
440    
441    /// Check for temporal references
442    fn has_temporal_references(&self, query: &str) -> bool {
443        let temporal_patterns = [
444            "yesterday", "today", "tomorrow", "last week", "last month",
445            "earlier", "before", "previously", "in the past",
446        ];
447        
448        let query_lower = query.to_lowercase();
449        temporal_patterns.iter().any(|p| query_lower.contains(p))
450    }
451}
452
/// Signals extracted from the conversation that drive tier and strategy selection.
#[derive(Debug, Default)]
struct ConversationAnalysis {
    /// Topics pulled from recent user messages.
    extracted_topics: Vec<String>,
    /// Whether one of the last few messages refers to past content.
    has_past_references: bool,
    /// Whether the user query asks for specific details.
    requires_specific_details: bool,
    /// Heuristic query complexity score (0.0 when there is no query).
    query_complexity: f32,
    /// Number of messages that were analyzed.
    conversation_length: usize,
    /// How recent topics relate to older ones.
    // NOTE(review): assigned during analysis but never read in this file.
    recency_pattern: RecencyPattern,
}

/// Relationship between the topics of the most recent messages and older ones.
#[derive(Debug, Clone, PartialEq, Default)]
enum RecencyPattern {
    /// Too few messages to compare; also the default.
    #[default]
    RecentOnly,
    /// Recent and older messages share more than one topic.
    TopicContinuation,
    /// Recent and older messages share no topics.
    TopicJumping,
    /// Recent and older messages share exactly one topic.
    Mixed,
}
473
474impl Clone for RetrievalPlanner {
475    fn clone(&self) -> Self {
476        Self {
477            database: self.database.clone(),
478            recent_threshold_messages: self.recent_threshold_messages,
479            max_retrieval_time_ms: self.max_retrieval_time_ms,
480        }
481    }
482}