1use crate::memory::Message;
2use crate::memory_db::MemoryDatabase;
3use std::sync::Arc;
4use tracing::{debug, info};
5
/// Describes which memory tiers and search strategies to use when assembling
/// context for the next turn, and how much retrieved content may be added.
///
/// The default value requests no retrieval: only tier 1 is enabled, every
/// search strategy is off, and no topics are set.
#[derive(Debug, Clone)]
pub struct RetrievalPlan {
    /// Whether any retrieval beyond the current messages should run.
    pub needs_retrieval: bool,

    /// Use memory tier 1 (presumably the recent, in-context messages — confirm).
    pub use_tier1: bool,
    /// Use memory tier 2 (session summaries).
    pub use_tier2: bool,
    /// Use memory tier 3 (stored messages in the database).
    pub use_tier3: bool,
    /// Search across other sessions rather than only the current one.
    pub cross_session_search: bool,

    /// Enable semantic (meaning-based) search.
    pub semantic_search: bool,
    /// Enable keyword search.
    pub keyword_search: bool,
    /// Enable time-based search.
    pub temporal_search: bool,

    /// Upper bound on the number of messages to retrieve.
    pub max_messages: usize,
    /// Token budget for the assembled context.
    pub max_tokens: usize,

    /// Target compression ratio applied to retrieved content.
    pub target_compression: f32,

    /// Topic phrases that retrieval should search for.
    pub search_topics: Vec<String>,
}

impl Default for RetrievalPlan {
    /// A "do nothing extra" plan: tier 1 only, no searches, up to 100 messages
    /// within 4000 tokens, compression target 0.3, no topics.
    fn default() -> Self {
        Self {
            // Limits.
            max_messages: 100,
            max_tokens: 4000,
            target_compression: 0.3,
            search_topics: Vec::default(),
            // Tier selection.
            needs_retrieval: false,
            use_tier1: true,
            use_tier2: false,
            use_tier3: false,
            cross_session_search: false,
            // Search strategies.
            semantic_search: false,
            keyword_search: false,
            temporal_search: false,
        }
    }
}
54
55pub struct RetrievalPlanner {
57 database: Arc<MemoryDatabase>,
58 recent_threshold_messages: usize,
59 max_retrieval_time_ms: u64,
60}
61
62impl RetrievalPlanner {
63 pub fn new(database: Arc<MemoryDatabase>) -> Self {
65 Self {
66 database,
67 recent_threshold_messages: 20,
68 max_retrieval_time_ms: 200,
69 }
70 }
71
72 pub async fn create_plan(
74 &self,
75 session_id: &str,
76 current_messages: &[Message],
77 max_context_tokens: usize,
78 user_query: Option<&str>,
79 has_past_refs: bool, ) -> anyhow::Result<RetrievalPlan> {
81 let mut plan = RetrievalPlan {
82 max_tokens: max_context_tokens,
83 ..Default::default()
84 };
85
86 let mut has_past_references_in_query = false;
88 if let Some(query) = user_query {
89 if self.is_cross_session_query(query, session_id) {
91 plan.needs_retrieval = true;
92 plan.cross_session_search = true;
93 plan.search_topics = self.extract_topics_from_query(query);
94 }
95
96 has_past_references_in_query = self.has_past_references_in_text(query);
98 }
99
100 if !has_past_references_in_query && has_past_refs {
102 has_past_references_in_query = true;
103 }
104
105 if !plan.needs_retrieval && !self.needs_retrieval(current_messages, max_context_tokens) {
107 if has_past_references_in_query {
109 plan.needs_retrieval = true;
110 debug!("Retrieval needed: query asks for past content");
111 } else {
112 debug!("No retrieval needed - within context limits and no past references");
113 return Ok(plan);
114 }
115 }
116
117 plan.needs_retrieval = true;
118
119 plan.use_tier1 = true;
121
122 let analysis = self.analyze_conversation(current_messages, user_query).await?;
124
125 self.plan_tier_usage(&mut plan, &analysis, session_id, has_past_references_in_query).await?;
127
128 self.plan_search_strategies(&mut plan, &analysis, user_query);
130
131 if plan.search_topics.is_empty() {
133 plan.search_topics = analysis.extracted_topics;
134 }
135
136 self.adjust_limits(&mut plan, current_messages, max_context_tokens);
138
139 info!(
140 "Created retrieval plan: Tiers({}{}{}), CrossSession({}), Search({}{}{}), PastRefs={}",
141 if plan.use_tier1 { "1" } else { "" },
142 if plan.use_tier2 { "2" } else { "" },
143 if plan.use_tier3 { "3" } else { "" },
144 plan.cross_session_search,
145 if plan.semantic_search { "S" } else { "" },
146 if plan.keyword_search { "K" } else { "" },
147 if plan.temporal_search { "T" } else { "" },
148 has_past_references_in_query
149 );
150
151 Ok(plan)
152 }
153
154 fn needs_retrieval(&self, messages: &[Message], max_tokens: usize) -> bool {
156 if messages.len() <= 1 {
157 return false;
158 }
159
160 let estimated_tokens: usize = messages.iter()
162 .map(|m| m.content.len() / 4)
163 .sum();
164
165 estimated_tokens > max_tokens
166 }
167
168 fn is_cross_session_query(&self, query: &str, _current_session_id: &str) -> bool {
170 let cross_session_patterns = [
171 "previously", "before", "earlier", "last time", "yesterday",
172 "do you remember", "we discussed", "we talked about",
173 "what did we talk", "remember when", "recall",
174 ];
175
176 let query_lower = query.to_lowercase();
177
178 cross_session_patterns.iter().any(|pattern| query_lower.contains(pattern))
180 }
181
182 pub fn has_past_references_in_text(&self, text: &str) -> bool { let reference_patterns = [
185 "earlier", "before", "previous", "last time", "yesterday",
186 "we discussed", "we talked about", "remember", "recall",
187 "did we talk", "have we discussed", "what did we say",
188 "what was said", "mentioned earlier", "previously mentioned",
189 ];
190
191 let text_lower = text.to_lowercase();
192 reference_patterns.iter().any(|p| text_lower.contains(p))
193 }
194
195 fn extract_topics_from_query(&self, query: &str) -> Vec<String> {
197 let words: Vec<&str> = query.split_whitespace().collect();
198 if words.len() < 3 {
199 return vec![query.to_string()];
200 }
201
202 let topic = words.iter()
204 .rev()
205 .take(4)
206 .rev()
207 .copied()
208 .collect::<Vec<&str>>()
209 .join(" ");
210
211 vec![topic]
212 }
213
214 async fn analyze_conversation(
216 &self,
217 messages: &[Message],
218 user_query: Option<&str>,
219 ) -> anyhow::Result<ConversationAnalysis> {
220 let mut analysis = ConversationAnalysis {
221 extracted_topics: self.extract_topics(messages),
222 has_past_references: self.has_past_references_in_messages(messages),
223 ..Default::default()
224 };
225
226 if let Some(query) = user_query {
228 analysis.requires_specific_details = self.requires_specific_details(query);
229 analysis.query_complexity = self.assess_query_complexity(query);
230 }
231
232 analysis.conversation_length = messages.len();
234 analysis.recency_pattern = self.analyze_recency_pattern(messages);
235
236 Ok(analysis)
237 }
238
239 async fn plan_tier_usage(
241 &self,
242 plan: &mut RetrievalPlan,
243 analysis: &ConversationAnalysis,
244 session_id: &str,
245 has_past_references_in_query: bool,
246 ) -> anyhow::Result<()> {
247 let has_summaries = self.database.summaries
248 .get_session_summaries(session_id)
249 .map(|summaries| !summaries.is_empty())
250 .unwrap_or_else(|e| {
251 debug!("Database error checking summaries: {}", e);
252 false
253 });
254
255 plan.use_tier2 = has_summaries;
256
257 let has_db_messages = self.check_if_session_has_db_messages(session_id).await?;
259
260 if has_past_references_in_query && has_db_messages {
263 plan.use_tier3 = true;
264 debug!("Query asks for past content, using Tier 3 (database)");
265 }
266
267 if analysis.requires_specific_details && has_db_messages {
269 plan.use_tier3 = true;
270 debug!("Specific details requested, using Tier 3");
271 }
272
273 if plan.cross_session_search {
275 plan.use_tier3 = true;
276 debug!("Cross-session search, using Tier 3");
277 }
278
279 if analysis.conversation_length > 30 && has_db_messages && !plan.use_tier3 {
281 plan.use_tier3 = true;
282 debug!("Long conversation ({} messages), using Tier 3", analysis.conversation_length);
283 }
284
285 if analysis.has_past_references && has_db_messages && !plan.use_tier3 {
287 plan.use_tier3 = true;
288 debug!("Past references in messages, using Tier 3");
289 }
290
291 if analysis.conversation_length > 100 {
292 plan.target_compression = 0.2;
293 }
294
295 Ok(())
296 }
297
298 async fn check_if_session_has_db_messages(&self, session_id: &str) -> anyhow::Result<bool> {
300 match self.database.conversations.get_session_messages(session_id, Some(1), Some(0)) {
302 Ok(messages) => Ok(!messages.is_empty()),
303 Err(e) => {
304 debug!("Error checking DB for session {}: {}", session_id, e);
305 Ok(false)
306 }
307 }
308 }
309
310 fn plan_search_strategies(
312 &self,
313 plan: &mut RetrievalPlan,
314 analysis: &ConversationAnalysis,
315 user_query: Option<&str>,
316 ) {
317 plan.semantic_search = analysis.query_complexity > 0.5 || (analysis.extracted_topics.is_empty() && !plan.cross_session_search);
319
320 plan.keyword_search = analysis.requires_specific_details
322 || analysis.has_past_references
323 || plan.cross_session_search
324 || !analysis.extracted_topics.is_empty();
325
326 plan.temporal_search = self.has_temporal_references(user_query.unwrap_or(""));
328 }
329
330 fn adjust_limits(
332 &self,
333 plan: &mut RetrievalPlan,
334 current_messages: &[Message],
335 max_context_tokens: usize,
336 ) {
337 let current_tokens: usize = current_messages.iter()
338 .map(|m| m.content.len() / 4)
339 .sum();
340
341 let available_for_retrieval = max_context_tokens.saturating_sub(current_tokens);
342
343 let estimated_messages = available_for_retrieval / 50;
345 plan.max_messages = estimated_messages.clamp(10, 100);
346 }
347
348 fn extract_topics(&self, messages: &[Message]) -> Vec<String> {
350 let mut topics = Vec::new();
351
352 for message in messages.iter().rev().filter(|m| m.role == "user").take(3) {
353 let words: Vec<&str> = message.content.split_whitespace().collect();
354
355 for i in 0..words.len().saturating_sub(2) {
356 let word_lower = words[i].to_lowercase();
357
358 if word_lower == "about" || word_lower == "regarding" {
359 let topic = words[i + 1..].iter()
360 .take(3)
361 .copied()
362 .collect::<Vec<&str>>()
363 .join(" ");
364
365 if !topic.is_empty() {
366 topics.push(topic);
367 }
368 }
369
370 if ["what", "how", "why", "when", "where", "who"].contains(&word_lower.as_str()) {
371 let topic = words[i + 1..].iter()
372 .take(4)
373 .copied()
374 .collect::<Vec<&str>>()
375 .join(" ");
376
377 if !topic.is_empty() {
378 topics.push(topic);
379 }
380 }
381 }
382 }
383
384 topics.dedup();
385 topics.truncate(3);
386
387 topics
388 }
389
390 fn has_past_references_in_messages(&self, messages: &[Message]) -> bool {
392 let reference_patterns = [
393 "earlier", "before", "previous", "last time", "yesterday",
394 "we discussed", "we talked about", "remember", "recall",
395 ];
396
397 for message in messages.iter().rev().take(5) {
398 let content_lower = message.content.to_lowercase();
399 if reference_patterns.iter().any(|p| content_lower.contains(p)) {
400 return true;
401 }
402 }
403
404 false
405 }
406
407 fn requires_specific_details(&self, query: &str) -> bool {
409 let detail_patterns = [
410 "exactly", "specifically", "in detail", "step by step",
411 "the code", "the number", "the date", "the name",
412 "show me", "give me", "tell me",
413 ];
414
415 let query_lower = query.to_lowercase();
416 detail_patterns.iter().any(|p| query_lower.contains(p))
417 }
418
419 fn assess_query_complexity(&self, query: &str) -> f32 {
421 let words: Vec<&str> = query.split_whitespace().collect();
422
423 if words.len() < 3 {
424 return 0.2;
425 }
426
427 let mut complexity = 0.0;
428 complexity += (words.len() as f32).min(50.0) / 100.0;
429
430 let clause_count = query.split(&[',', ';', '&']).count();
431 complexity += (clause_count as f32).min(5.0) / 10.0;
432
433 let technical_terms = ["code", "function", "algorithm", "parameter", "variable"];
434 for term in technical_terms {
435 if query.to_lowercase().contains(term) {
436 complexity += 0.2;
437 }
438 }
439
440 complexity.min(1.0)
441 }
442
443 fn analyze_recency_pattern(&self, messages: &[Message]) -> RecencyPattern {
445 if messages.len() < 5 {
446 return RecencyPattern::RecentOnly;
447 }
448
449 let recent_topics = self.extract_topics(&messages[messages.len().saturating_sub(5)..]);
450 let older_topics = self.extract_topics(&messages[..messages.len().saturating_sub(5)]);
451
452 let overlap = recent_topics.iter()
453 .filter(|topic| older_topics.contains(topic))
454 .count();
455
456 match overlap {
457 0 => RecencyPattern::TopicJumping,
458 1 => RecencyPattern::Mixed,
459 _ => RecencyPattern::TopicContinuation,
460 }
461 }
462
463 fn has_temporal_references(&self, query: &str) -> bool {
465 let temporal_patterns = [
466 "yesterday", "today", "tomorrow", "last week", "last month",
467 "earlier", "before", "previously", "in the past",
468 ];
469
470 let query_lower = query.to_lowercase();
471 temporal_patterns.iter().any(|p| query_lower.contains(p))
472 }
473}
474
/// Heuristic signals gathered from the current conversation and user query,
/// consumed when choosing tiers and search strategies.
#[derive(Debug)]
struct ConversationAnalysis {
    /// Topic phrases mined from recent user messages.
    extracted_topics: Vec<String>,
    /// Whether recent messages contain past-reference phrases.
    has_past_references: bool,
    /// Whether the query asks for precise details.
    requires_specific_details: bool,
    /// Heuristic query complexity score in `[0.0, 1.0]`.
    query_complexity: f32,
    /// Number of messages in the conversation.
    conversation_length: usize,
    /// How recent topics relate to earlier ones.
    recency_pattern: RecencyPattern,
}

impl Default for ConversationAnalysis {
    /// All signals off: no topics, no flags, zero complexity and length, and
    /// `RecencyPattern::RecentOnly`.
    fn default() -> Self {
        ConversationAnalysis {
            extracted_topics: Vec::new(),
            has_past_references: false,
            requires_specific_details: false,
            query_complexity: 0.0,
            conversation_length: 0,
            recency_pattern: RecencyPattern::default(),
        }
    }
}

/// Relationship between the topics of the last few messages and those of
/// earlier messages.
#[derive(Debug, Clone, PartialEq)]
enum RecencyPattern {
    /// Too few messages to compare; also the default.
    RecentOnly,
    /// Recent messages share two or more topics with earlier ones.
    TopicContinuation,
    /// Recent messages share no topics with earlier ones.
    TopicJumping,
    /// Recent messages share exactly one topic with earlier ones.
    Mixed,
}

impl Default for RecencyPattern {
    fn default() -> Self {
        RecencyPattern::RecentOnly
    }
}
495
496impl Clone for RetrievalPlanner {
497 fn clone(&self) -> Self {
498 Self {
499 database: self.database.clone(),
500 recent_threshold_messages: self.recent_threshold_messages,
501 max_retrieval_time_ms: self.max_retrieval_time_ms,
502 }
503 }
504}