1use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14use crate::agency::{Agent, AgentBuilder, AgentConfig};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct SearchContext {
19 pub recent_queries: Vec<String>,
21 pub recent_sessions: Vec<String>,
23 pub workspace_id: Option<String>,
25 pub providers: Vec<String>,
27 pub preferences: SearchPreferences,
29 pub time_range: Option<TimeRange>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct TimeRange {
36 pub start: Option<DateTime<Utc>>,
37 pub end: Option<DateTime<Utc>>,
38}
39
40#[derive(Debug, Clone, Default, Serialize, Deserialize)]
42pub struct SearchPreferences {
43 pub result_limit: u32,
45 pub semantic_enabled: bool,
47 pub include_archived: bool,
49 pub highlight_matches: bool,
51 pub group_by_session: bool,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct QueryRefinement {
58 pub query: String,
60 pub refinement_type: RefinementType,
62 pub confidence: f64,
64 pub explanation: String,
66 pub expected_improvement: String,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
72#[serde(rename_all = "snake_case")]
73pub enum RefinementType {
74 Specificity,
76 Broadening,
78 Correction,
80 Synonyms,
82 Contextual,
84 Temporal,
86 ProviderFilter,
88 SemanticExpansion,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct EnrichedSearchResult {
95 pub session_id: String,
97 pub title: String,
99 pub relevance: f64,
101 pub snippets: Vec<String>,
103 pub match_reason: String,
105 pub follow_ups: Vec<String>,
107}
108
109#[derive(Debug, Clone, Default, Serialize, Deserialize)]
111pub struct SearchAnalytics {
112 pub total_searches: u64,
114 pub successful_searches: u64,
116 pub refinements_suggested: u64,
118 pub refinements_accepted: u64,
120 pub avg_relevance: f64,
122 pub common_patterns: HashMap<String, u32>,
124}
125
126pub struct SearchAgentState {
128 search_history: Vec<SearchHistoryEntry>,
130 analytics: SearchAnalytics,
132 patterns: Vec<QueryPattern>,
134 context_cache: HashMap<String, SearchContext>,
136}
137
138#[derive(Debug, Clone)]
140struct SearchHistoryEntry {
141 query: String,
142 timestamp: DateTime<Utc>,
143 result_count: u32,
144 refinements_used: Vec<String>,
145}
146
/// Aggregated statistics for one query shape.
#[derive(Debug, Clone)]
struct QueryPattern {
    /// Pattern string produced by `SearchRefinementAgent::extract_pattern`.
    pattern: String,
    /// How many recorded searches mapped to this pattern.
    frequency: u32,
    /// Running mean of the result counts of those searches.
    avg_results: f64,
    /// Refinements reported with the first search that created this pattern;
    /// `record_search` does not update it afterwards.
    best_refinements: Vec<String>,
}
155
156pub struct SearchRefinementAgent {
158 config: AgentConfig,
160 state: Arc<RwLock<SearchAgentState>>,
162}
163
164impl SearchRefinementAgent {
165 pub fn new() -> Self {
167 let config = AgentConfig {
168 name: "search-refinement-agent".to_string(),
169 description: "Context-aware search query refinement".to_string(),
170 instruction: SEARCH_SYSTEM_PROMPT.to_string(),
171 ..Default::default()
172 };
173
174 let state = SearchAgentState {
175 search_history: Vec::new(),
176 analytics: SearchAnalytics::default(),
177 patterns: Vec::new(),
178 context_cache: HashMap::new(),
179 };
180
181 Self {
182 config,
183 state: Arc::new(RwLock::new(state)),
184 }
185 }
186
187 pub async fn refine_query(
189 &self,
190 query: &str,
191 context: Option<SearchContext>,
192 ) -> Vec<QueryRefinement> {
193 let mut refinements = Vec::new();
194 let query_lower = query.to_lowercase();
195
196 let corrections = self.check_spelling(query);
198 for correction in corrections {
199 refinements.push(QueryRefinement {
200 query: correction.clone(),
201 refinement_type: RefinementType::Correction,
202 confidence: 0.9,
203 explanation: "Corrected potential typo".to_string(),
204 expected_improvement: "More accurate results".to_string(),
205 });
206 }
207
208 let synonyms = self.find_synonyms(&query_lower);
210 for synonym in synonyms {
211 refinements.push(QueryRefinement {
212 query: format!("{} OR {}", query, synonym),
213 refinement_type: RefinementType::Synonyms,
214 confidence: 0.75,
215 explanation: format!("Added synonym: {}", synonym),
216 expected_improvement: "Broader coverage".to_string(),
217 });
218 }
219
220 if let Some(ctx) = context {
222 if !ctx.recent_queries.is_empty() {
224 let combined = format!("{} {}", query, ctx.recent_queries.last().unwrap());
225 refinements.push(QueryRefinement {
226 query: combined,
227 refinement_type: RefinementType::Contextual,
228 confidence: 0.7,
229 explanation: "Combined with recent search".to_string(),
230 expected_improvement: "More relevant to your current focus".to_string(),
231 });
232 }
233
234 if ctx.providers.len() == 1 {
236 refinements.push(QueryRefinement {
237 query: format!("{} provider:{}", query, ctx.providers[0]),
238 refinement_type: RefinementType::ProviderFilter,
239 confidence: 0.8,
240 explanation: format!("Filtered to {} sessions", ctx.providers[0]),
241 expected_improvement: "Focused on your active provider".to_string(),
242 });
243 }
244
245 refinements.push(QueryRefinement {
247 query: format!("{} after:7days", query),
248 refinement_type: RefinementType::Temporal,
249 confidence: 0.6,
250 explanation: "Limited to last 7 days".to_string(),
251 expected_improvement: "Recent and relevant results".to_string(),
252 });
253 }
254
255 if query.split_whitespace().count() < 3 {
257 let specific_suggestions = self.suggest_specific_terms(&query_lower).await;
258 for suggestion in specific_suggestions {
259 refinements.push(QueryRefinement {
260 query: format!("{} {}", query, suggestion),
261 refinement_type: RefinementType::Specificity,
262 confidence: 0.65,
263 explanation: format!("Added specific term: {}", suggestion),
264 expected_improvement: "More targeted results".to_string(),
265 });
266 }
267 }
268
269 if self.is_technical_query(&query_lower) {
271 refinements.push(QueryRefinement {
272 query: query.to_string(),
273 refinement_type: RefinementType::SemanticExpansion,
274 confidence: 0.85,
275 explanation: "Use semantic search for technical content".to_string(),
276 expected_improvement: "Find conceptually related discussions".to_string(),
277 });
278 }
279
280 refinements.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
282
283 {
285 let mut state = self.state.write().await;
286 state.analytics.refinements_suggested += refinements.len() as u64;
287 }
288
289 refinements
290 }
291
292 pub async fn record_search(
294 &self,
295 query: &str,
296 result_count: u32,
297 refinements_used: Vec<String>,
298 ) {
299 let mut state = self.state.write().await;
300
301 state.search_history.push(SearchHistoryEntry {
302 query: query.to_string(),
303 timestamp: Utc::now(),
304 result_count,
305 refinements_used: refinements_used.clone(),
306 });
307
308 let history_len = state.search_history.len();
310 if history_len > 1000 {
311 state.search_history.drain(0..history_len - 1000);
312 }
313
314 state.analytics.total_searches += 1;
316 if result_count > 0 {
317 state.analytics.successful_searches += 1;
318 }
319 if !refinements_used.is_empty() {
320 state.analytics.refinements_accepted += 1;
321 }
322
323 let pattern = self.extract_pattern(query);
325 if let Some(existing) = state.patterns.iter_mut().find(|p| p.pattern == pattern) {
326 existing.frequency += 1;
327 existing.avg_results = (existing.avg_results * (existing.frequency - 1) as f64
328 + result_count as f64)
329 / existing.frequency as f64;
330 } else {
331 state.patterns.push(QueryPattern {
332 pattern,
333 frequency: 1,
334 avg_results: result_count as f64,
335 best_refinements: refinements_used,
336 });
337 }
338 }
339
340 pub async fn get_analytics(&self) -> SearchAnalytics {
342 let state = self.state.read().await;
343 state.analytics.clone()
344 }
345
346 pub async fn suggest_follow_ups(&self, _session_id: &str, query: &str) -> Vec<String> {
348 let mut suggestions = Vec::new();
349
350 suggestions.push(format!("{} example", query));
352 suggestions.push(format!("{} solution", query));
353 suggestions.push(format!("related to {}", query));
354
355 suggestions
359 }
360
361 fn check_spelling(&self, query: &str) -> Vec<String> {
363 let mut corrections = Vec::new();
364
365 let corrections_map: HashMap<&str, &str> = [
367 ("javascrip", "javascript"),
368 ("pytohn", "python"),
369 ("typescrip", "typescript"),
370 ("fucntion", "function"),
371 ("aync", "async"),
372 ("awiat", "await"),
373 ("improt", "import"),
374 ("exprot", "export"),
375 ("cosnt", "const"),
376 ("retrun", "return"),
377 ]
378 .iter()
379 .cloned()
380 .collect();
381
382 let _words: Vec<&str> = query.split_whitespace().collect();
383 for (typo, correct) in &corrections_map {
384 if query.to_lowercase().contains(typo) {
385 let corrected = query.to_lowercase().replace(typo, correct);
386 corrections.push(corrected);
387 }
388 }
389
390 corrections
391 }
392
393 fn find_synonyms(&self, query: &str) -> Vec<String> {
395 let mut synonyms = Vec::new();
396
397 let synonym_map: HashMap<&str, Vec<&str>> = [
398 ("error", vec!["exception", "bug", "issue", "problem"]),
399 ("function", vec!["method", "procedure", "routine"]),
400 ("variable", vec!["var", "const", "let", "parameter"]),
401 ("create", vec!["make", "generate", "build", "new"]),
402 ("delete", vec!["remove", "destroy", "drop"]),
403 ("find", vec!["search", "locate", "query", "get"]),
404 ("update", vec!["modify", "change", "edit", "patch"]),
405 ("api", vec!["endpoint", "route", "service"]),
406 ("database", vec!["db", "storage", "repository"]),
407 ]
408 .iter()
409 .cloned()
410 .collect();
411
412 for (term, syns) in &synonym_map {
413 if query.contains(term) {
414 for syn in syns {
415 synonyms.push(syn.to_string());
416 }
417 }
418 }
419
420 synonyms.truncate(3); synonyms
422 }
423
424 async fn suggest_specific_terms(&self, query: &str) -> Vec<String> {
426 let mut suggestions = Vec::new();
427
428 if query.contains("error") || query.contains("bug") {
430 suggestions.push("fix".to_string());
431 suggestions.push("solution".to_string());
432 }
433 if query.contains("how") {
434 suggestions.push("step-by-step".to_string());
435 suggestions.push("example".to_string());
436 }
437 if query.contains("best") {
438 suggestions.push("practice".to_string());
439 suggestions.push("approach".to_string());
440 }
441
442 suggestions.truncate(2);
443 suggestions
444 }
445
446 fn is_technical_query(&self, query: &str) -> bool {
448 let technical_terms = [
449 "function",
450 "class",
451 "method",
452 "api",
453 "error",
454 "bug",
455 "code",
456 "implement",
457 "debug",
458 "async",
459 "await",
460 "promise",
461 "callback",
462 "component",
463 "module",
464 "import",
465 "export",
466 "typescript",
467 "javascript",
468 "python",
469 "rust",
470 "react",
471 "vue",
472 "angular",
473 "node",
474 "sql",
475 ];
476
477 technical_terms.iter().any(|term| query.contains(term))
478 }
479
480 fn extract_pattern(&self, query: &str) -> String {
482 let words: Vec<&str> = query.split_whitespace().collect();
484 if words.len() <= 2 {
485 return query.to_lowercase();
486 }
487
488 words
490 .iter()
491 .map(|w| if w.len() > 5 { "[TERM]" } else { *w })
492 .collect::<Vec<_>>()
493 .join(" ")
494 }
495}
496
497impl Default for SearchRefinementAgent {
498 fn default() -> Self {
499 Self::new()
500 }
501}
502
/// System prompt installed as the agent's `instruction` by
/// `SearchRefinementAgent::new`.
const SEARCH_SYSTEM_PROMPT: &str = r#"You are a context-aware search refinement agent for Chasm.

Your role is to help users find relevant chat sessions by:
1. Understanding the intent behind their search queries
2. Suggesting refinements that will improve results
3. Learning from search patterns to make better suggestions
4. Providing contextual suggestions based on recent activity

When refining a query, consider:
- Is the query too broad or too specific?
- Are there common synonyms or related terms?
- Does the user's recent activity suggest a focus area?
- Would time-based or provider-based filters help?

Always explain why a refinement might help.
"#;
520
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh agent starts with zeroed analytics.
    #[tokio::test]
    async fn test_search_agent_creation() {
        let agent = SearchRefinementAgent::new();
        let analytics = agent.get_analytics().await;
        assert_eq!(analytics.total_searches, 0);
    }

    /// A technical query with a known synonym term yields suggestions even
    /// without context.
    #[tokio::test]
    async fn test_refine_query_basic() {
        let agent = SearchRefinementAgent::new();
        let refinements = agent.refine_query("python error", None).await;
        assert!(!refinements.is_empty());
    }

    /// Context with a recent query and a single provider must produce the
    /// contextual, provider-filter and temporal suggestions.
    #[tokio::test]
    async fn test_refine_query_with_context() {
        let agent = SearchRefinementAgent::new();
        let context = SearchContext {
            recent_queries: vec!["async await".to_string()],
            recent_sessions: vec![],
            workspace_id: Some("test-workspace".to_string()),
            providers: vec!["copilot".to_string()],
            preferences: SearchPreferences::default(),
            time_range: None,
        };
        let refinements = agent.refine_query("function", Some(context)).await;

        let has_type = |wanted: RefinementType| {
            refinements.iter().any(|r| r.refinement_type == wanted)
        };
        // Assert each expected kind on its own; the previous
        // `has_contextual || !refinements.is_empty()` could not fail when the
        // contextual suggestion was missing.
        assert!(has_type(RefinementType::Contextual));
        assert!(has_type(RefinementType::ProviderFilter));
        assert!(has_type(RefinementType::Temporal));
    }

    /// A known typo produces a correction suggestion.
    #[tokio::test]
    async fn test_spelling_correction() {
        let agent = SearchRefinementAgent::new();
        let refinements = agent.refine_query("pytohn function", None).await;

        let has_correction = refinements
            .iter()
            .any(|r| r.refinement_type == RefinementType::Correction);
        assert!(has_correction);
    }

    /// Recording a search with results bumps the total and success counters
    /// but not the accepted-refinement counter when no refinements were used.
    #[tokio::test]
    async fn test_record_search() {
        let agent = SearchRefinementAgent::new();
        agent.record_search("test query", 10, vec![]).await;

        let analytics = agent.get_analytics().await;
        assert_eq!(analytics.total_searches, 1);
        assert_eq!(analytics.successful_searches, 1);
        assert_eq!(analytics.refinements_accepted, 0);
    }
}