1use crate::memory::Message;
2use crate::memory_db::MemoryDatabase;
3use std::sync::Arc;
4use tracing::{debug, info};
5
/// Outcome of [`RetrievalPlanner::create_plan`]: whether to fetch past context
/// for the current turn, which memory tiers and search strategies to use, and
/// how much to fetch.
#[derive(Debug, Clone)]
pub struct RetrievalPlan {
    /// Whether any retrieval should happen for this turn.
    pub needs_retrieval: bool,

    /// Consult tier 1. Enabled by default and always set by the planner when it retrieves.
    pub use_tier1: bool,
    /// Consult tier 3 (the planner's log messages describe it as the database).
    pub use_tier3: bool,
    /// Search beyond the current session (set for "do you remember…"-style queries).
    pub cross_session_search: bool,

    /// Enable semantic search.
    pub semantic_search: bool,
    /// Enable keyword search (presumably over `search_topics` — confirm with the consumer).
    pub keyword_search: bool,
    /// Enable time-based search (the query mentions e.g. "yesterday" or "last week").
    pub temporal_search: bool,

    /// Maximum number of messages to retrieve; the planner clamps it to `10..=100`.
    pub max_messages: usize,
    /// Token budget; the planner copies its `max_context_tokens` argument here.
    pub max_tokens: usize,

    /// Phrases to search for.
    pub search_topics: Vec<String>,
}

impl Default for RetrievalPlan {
    /// An inactive plan: no retrieval, only tier 1 switched on, no search
    /// strategies, a 100-message / 4000-token budget, and no topics.
    fn default() -> Self {
        Self {
            search_topics: Vec::new(),
            max_tokens: 4000,
            max_messages: 100,
            needs_retrieval: false,
            cross_session_search: false,
            use_tier1: true,
            use_tier3: false,
            semantic_search: false,
            keyword_search: false,
            temporal_search: false,
        }
    }
}
48
49pub struct RetrievalPlanner {
51 database: Arc<MemoryDatabase>,
52 recent_threshold_messages: usize,
53 max_retrieval_time_ms: u64,
54}
55
56impl RetrievalPlanner {
57 pub fn new(database: Arc<MemoryDatabase>) -> Self {
59 Self {
60 database,
61 recent_threshold_messages: 20,
62 max_retrieval_time_ms: 200,
63 }
64 }
65
66 pub async fn create_plan(
68 &self,
69 session_id: &str,
70 current_messages: &[Message],
71 max_context_tokens: usize,
72 user_query: Option<&str>,
73 has_past_refs: bool, ) -> anyhow::Result<RetrievalPlan> {
75 let mut plan = RetrievalPlan {
76 max_tokens: max_context_tokens,
77 ..Default::default()
78 };
79
80 let mut has_past_references_in_query = false;
82 if let Some(query) = user_query {
83 if self.is_cross_session_query(query, session_id) {
85 plan.needs_retrieval = true;
86 plan.cross_session_search = true;
87 plan.search_topics = self.extract_topics_from_query(query);
88 }
89
90 has_past_references_in_query = self.has_past_references_in_text(query);
92 }
93
94 if !has_past_references_in_query && has_past_refs {
96 has_past_references_in_query = true;
97 }
98
99 if !plan.needs_retrieval && !self.needs_retrieval(current_messages, max_context_tokens) {
101 if has_past_references_in_query {
103 plan.needs_retrieval = true;
104 debug!("Retrieval needed: query asks for past content");
105 } else {
106 debug!("No retrieval needed - within context limits and no past references");
107 return Ok(plan);
108 }
109 }
110
111 plan.needs_retrieval = true;
112
113 plan.use_tier1 = true;
115
116 let analysis = self.analyze_conversation(current_messages, user_query).await?;
118
119 self.plan_tier_usage(&mut plan, &analysis, session_id, has_past_references_in_query).await?;
121
122 self.plan_search_strategies(&mut plan, &analysis, user_query);
124
125 if plan.search_topics.is_empty() {
127 plan.search_topics = analysis.extracted_topics;
128 }
129
130 self.adjust_limits(&mut plan, current_messages, max_context_tokens);
132
133 info!(
134 "Created retrieval plan: Tiers({}{}), CrossSession({}), Search({}{}{}), PastRefs={}",
135 if plan.use_tier1 { "1" } else { "" },
136 if plan.use_tier3 { "3" } else { "" },
137 plan.cross_session_search,
138 if plan.semantic_search { "S" } else { "" },
139 if plan.keyword_search { "K" } else { "" },
140 if plan.temporal_search { "T" } else { "" },
141 has_past_references_in_query
142 );
143
144 Ok(plan)
145 }
146
147 fn needs_retrieval(&self, messages: &[Message], max_tokens: usize) -> bool {
149 if messages.len() <= 1 {
150 return false;
151 }
152
153 let estimated_tokens: usize = messages.iter()
155 .map(|m| m.content.len() / 4)
156 .sum();
157
158 estimated_tokens > max_tokens
159 }
160
161 fn is_cross_session_query(&self, query: &str, _current_session_id: &str) -> bool {
163 let cross_session_patterns = [
164 "previously", "before", "earlier", "last time", "yesterday",
165 "do you remember", "we discussed", "we talked about",
166 "what did we talk", "remember when", "recall",
167 ];
168
169 let query_lower = query.to_lowercase();
170
171 cross_session_patterns.iter().any(|pattern| query_lower.contains(pattern))
173 }
174
175 pub fn has_past_references_in_text(&self, text: &str) -> bool { let reference_patterns = [
178 "earlier", "before", "previous", "last time", "yesterday",
179 "we discussed", "we talked about", "remember", "recall",
180 "did we talk", "have we discussed", "what did we say",
181 "what was said", "mentioned earlier", "previously mentioned",
182 ];
183
184 let text_lower = text.to_lowercase();
185 reference_patterns.iter().any(|p| text_lower.contains(p))
186 }
187
188 fn extract_topics_from_query(&self, query: &str) -> Vec<String> {
190 let words: Vec<&str> = query.split_whitespace().collect();
191 if words.len() < 3 {
192 return vec![query.to_string()];
193 }
194
195 let topic = words.iter()
197 .rev()
198 .take(4)
199 .rev()
200 .copied()
201 .collect::<Vec<&str>>()
202 .join(" ");
203
204 vec![topic]
205 }
206
207 async fn analyze_conversation(
209 &self,
210 messages: &[Message],
211 user_query: Option<&str>,
212 ) -> anyhow::Result<ConversationAnalysis> {
213 let mut analysis = ConversationAnalysis {
214 extracted_topics: self.extract_topics(messages),
215 has_past_references: self.has_past_references_in_messages(messages),
216 ..Default::default()
217 };
218
219 if let Some(query) = user_query {
221 analysis.requires_specific_details = self.requires_specific_details(query);
222 analysis.query_complexity = self.assess_query_complexity(query);
223 }
224
225 analysis.conversation_length = messages.len();
227 analysis.recency_pattern = self.analyze_recency_pattern(messages);
228
229 Ok(analysis)
230 }
231
232 async fn plan_tier_usage(
234 &self,
235 plan: &mut RetrievalPlan,
236 analysis: &ConversationAnalysis,
237 session_id: &str,
238 has_past_references_in_query: bool,
239 ) -> anyhow::Result<()> {
240 let has_db_messages = self.check_if_session_has_db_messages(session_id).await?;
241
242 if has_past_references_in_query && has_db_messages {
245 plan.use_tier3 = true;
246 debug!("Query asks for past content, using Tier 3 (database)");
247 }
248
249 if analysis.requires_specific_details && has_db_messages {
251 plan.use_tier3 = true;
252 debug!("Specific details requested, using Tier 3");
253 }
254
255 if plan.cross_session_search {
257 plan.use_tier3 = true;
258 debug!("Cross-session search, using Tier 3");
259 }
260
261 if analysis.conversation_length > 30 && has_db_messages && !plan.use_tier3 {
263 plan.use_tier3 = true;
264 debug!("Long conversation ({} messages), using Tier 3", analysis.conversation_length);
265 }
266
267 if analysis.has_past_references && has_db_messages && !plan.use_tier3 {
269 plan.use_tier3 = true;
270 debug!("Past references in messages, using Tier 3");
271 }
272
273 Ok(())
274 }
275
276 async fn check_if_session_has_db_messages(&self, session_id: &str) -> anyhow::Result<bool> {
278 match self.database.conversations.get_session_messages(session_id, Some(1), Some(0)) {
280 Ok(messages) => Ok(!messages.is_empty()),
281 Err(e) => {
282 debug!("Error checking DB for session {}: {}", session_id, e);
283 Ok(false)
284 }
285 }
286 }
287
288 fn plan_search_strategies(
290 &self,
291 plan: &mut RetrievalPlan,
292 analysis: &ConversationAnalysis,
293 user_query: Option<&str>,
294 ) {
295 plan.semantic_search = analysis.query_complexity > 0.5 || (analysis.extracted_topics.is_empty() && !plan.cross_session_search);
297
298 plan.keyword_search = analysis.requires_specific_details
300 || analysis.has_past_references
301 || plan.cross_session_search
302 || !analysis.extracted_topics.is_empty();
303
304 plan.temporal_search = self.has_temporal_references(user_query.unwrap_or(""));
306 }
307
308 fn adjust_limits(
310 &self,
311 plan: &mut RetrievalPlan,
312 current_messages: &[Message],
313 max_context_tokens: usize,
314 ) {
315 let current_tokens: usize = current_messages.iter()
316 .map(|m| m.content.len() / 4)
317 .sum();
318
319 let available_for_retrieval = max_context_tokens.saturating_sub(current_tokens);
320
321 let estimated_messages = available_for_retrieval / 50;
323 plan.max_messages = estimated_messages.clamp(10, 100);
324 }
325
326 fn extract_topics(&self, messages: &[Message]) -> Vec<String> {
328 let mut topics = Vec::new();
329
330 for message in messages.iter().rev().filter(|m| m.role == "user").take(3) {
331 let words: Vec<&str> = message.content.split_whitespace().collect();
332
333 for i in 0..words.len().saturating_sub(2) {
334 let word_lower = words[i].to_lowercase();
335
336 if word_lower == "about" || word_lower == "regarding" {
337 let topic = words[i + 1..].iter()
338 .take(3)
339 .copied()
340 .collect::<Vec<&str>>()
341 .join(" ");
342
343 if !topic.is_empty() {
344 topics.push(topic);
345 }
346 }
347
348 if ["what", "how", "why", "when", "where", "who"].contains(&word_lower.as_str()) {
349 let topic = words[i + 1..].iter()
350 .take(4)
351 .copied()
352 .collect::<Vec<&str>>()
353 .join(" ");
354
355 if !topic.is_empty() {
356 topics.push(topic);
357 }
358 }
359 }
360 }
361
362 topics.dedup();
363 topics.truncate(3);
364
365 topics
366 }
367
368 fn has_past_references_in_messages(&self, messages: &[Message]) -> bool {
370 let reference_patterns = [
371 "earlier", "before", "previous", "last time", "yesterday",
372 "we discussed", "we talked about", "remember", "recall",
373 ];
374
375 for message in messages.iter().rev().take(5) {
376 let content_lower = message.content.to_lowercase();
377 if reference_patterns.iter().any(|p| content_lower.contains(p)) {
378 return true;
379 }
380 }
381
382 false
383 }
384
385 fn requires_specific_details(&self, query: &str) -> bool {
387 let detail_patterns = [
388 "exactly", "specifically", "in detail", "step by step",
389 "the code", "the number", "the date", "the name",
390 "show me", "give me", "tell me",
391 ];
392
393 let query_lower = query.to_lowercase();
394 detail_patterns.iter().any(|p| query_lower.contains(p))
395 }
396
397 fn assess_query_complexity(&self, query: &str) -> f32 {
399 let words: Vec<&str> = query.split_whitespace().collect();
400
401 if words.len() < 3 {
402 return 0.2;
403 }
404
405 let mut complexity = 0.0;
406 complexity += (words.len() as f32).min(50.0) / 100.0;
407
408 let clause_count = query.split(&[',', ';', '&']).count();
409 complexity += (clause_count as f32).min(5.0) / 10.0;
410
411 let technical_terms = ["code", "function", "algorithm", "parameter", "variable"];
412 for term in technical_terms {
413 if query.to_lowercase().contains(term) {
414 complexity += 0.2;
415 }
416 }
417
418 complexity.min(1.0)
419 }
420
421 fn analyze_recency_pattern(&self, messages: &[Message]) -> RecencyPattern {
423 if messages.len() < 5 {
424 return RecencyPattern::RecentOnly;
425 }
426
427 let recent_topics = self.extract_topics(&messages[messages.len().saturating_sub(5)..]);
428 let older_topics = self.extract_topics(&messages[..messages.len().saturating_sub(5)]);
429
430 let overlap = recent_topics.iter()
431 .filter(|topic| older_topics.contains(topic))
432 .count();
433
434 match overlap {
435 0 => RecencyPattern::TopicJumping,
436 1 => RecencyPattern::Mixed,
437 _ => RecencyPattern::TopicContinuation,
438 }
439 }
440
441 fn has_temporal_references(&self, query: &str) -> bool {
443 let temporal_patterns = [
444 "yesterday", "today", "tomorrow", "last week", "last month",
445 "earlier", "before", "previously", "in the past",
446 ];
447
448 let query_lower = query.to_lowercase();
449 temporal_patterns.iter().any(|p| query_lower.contains(p))
450 }
451}
452
/// Heuristic signals gathered from the conversation window and the user query;
/// they drive tier and search-strategy selection in `RetrievalPlanner`.
#[derive(Debug, Default)]
struct ConversationAnalysis {
    /// Short phrases following "about"/"regarding" or a question word in recent user messages.
    extracted_topics: Vec<String>,
    /// Whether one of the last five messages contains a past-reference phrase.
    has_past_references: bool,
    /// Whether the query asks for precise details ("exactly", "show me", …).
    requires_specific_details: bool,
    /// Heuristic complexity score in `[0.0, 1.0]`; `0.0` when there is no query.
    query_complexity: f32,
    /// Number of messages in the current window.
    conversation_length: usize,
    /// How recent topics relate to older ones.
    /// NOTE(review): computed but not consulted by the planner.
    recency_pattern: RecencyPattern,
}

/// Relationship between the topics of the last five messages and older ones.
#[derive(Debug, Clone, PartialEq, Default)]
enum RecencyPattern {
    /// Fewer than five messages, so there is no older history to compare against.
    #[default]
    RecentOnly,
    /// Two or more topics are shared with older messages.
    TopicContinuation,
    /// No topics are shared with older messages.
    TopicJumping,
    /// Exactly one topic is shared with older messages.
    Mixed,
}
473
474impl Clone for RetrievalPlanner {
475 fn clone(&self) -> Self {
476 Self {
477 database: self.database.clone(),
478 recent_threshold_messages: self.recent_threshold_messages,
479 max_retrieval_time_ms: self.max_retrieval_time_ms,
480 }
481 }
482}