Skip to main content

chasm/agency/
search_refinement.rs

// Copyright (c) 2024-2027 Nervosys LLC
// SPDX-License-Identifier: AGPL-3.0-only
//! Context-Aware Search Refinement Agent
//!
//! An AI agent that understands search context and suggests query refinements
//! for better search results across chat sessions.
7
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14use crate::agency::{Agent, AgentBuilder, AgentConfig};
15
/// Search context from user's recent activity.
///
/// Passed (optionally) to `SearchRefinementAgent::refine_query`, which uses it
/// to produce contextual, provider-filter and temporal suggestions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchContext {
    /// Recent search queries; `refine_query` combines the current query with
    /// the last element of this list.
    pub recent_queries: Vec<String>,
    /// Recently viewed sessions
    pub recent_sessions: Vec<String>,
    /// Active workspace
    pub workspace_id: Option<String>,
    /// Active providers; `refine_query` suggests a `provider:` filter only
    /// when exactly one provider is listed.
    pub providers: Vec<String>,
    /// User preferences
    pub preferences: SearchPreferences,
    /// Time range of interest
    pub time_range: Option<TimeRange>,
}
32
/// Time range for search.
///
/// Both bounds are optional, so a value may carry only a start, only an end,
/// both, or neither.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeRange {
    /// Start of the range (UTC), if any
    pub start: Option<DateTime<Utc>>,
    /// End of the range (UTC), if any
    pub end: Option<DateTime<Utc>>,
}
39
/// User search preferences.
///
/// The derived `Default` yields `result_limit == 0` and every flag `false`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SearchPreferences {
    /// Preferred result count
    pub result_limit: u32,
    /// Semantic search enabled
    pub semantic_enabled: bool,
    /// Include archived sessions
    pub include_archived: bool,
    /// Highlight matches
    pub highlight_matches: bool,
    /// Group by session
    pub group_by_session: bool,
}
54
/// Query refinement suggestion.
///
/// Produced by `SearchRefinementAgent::refine_query`, which returns these
/// ordered by descending `confidence`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRefinement {
    /// Refined query
    pub query: String,
    /// Refinement type
    pub refinement_type: RefinementType,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f64,
    /// Human-readable explanation of what was changed
    pub explanation: String,
    /// Expected result improvement
    pub expected_improvement: String,
}
69
/// Type of query refinement.
///
/// Serialized in `snake_case` (for example `provider_filter`,
/// `semantic_expansion`) because of the `rename_all` attribute below.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum RefinementType {
    /// Add more specific terms
    Specificity,
    /// Broaden the search
    Broadening,
    /// Correct spelling/typos
    Correction,
    /// Add synonyms
    Synonyms,
    /// Add context from recent activity
    Contextual,
    /// Filter by time
    Temporal,
    /// Filter by provider
    ProviderFilter,
    /// Semantic expansion
    SemanticExpansion,
}
91
/// Search result with analysis.
///
/// NOTE(review): nothing in the visible part of this file constructs this
/// type; presumably it is produced by callers elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnrichedSearchResult {
    /// Identifier of the session the result came from
    pub session_id: String,
    /// Title
    pub title: String,
    /// Relevance score
    pub relevance: f64,
    /// Matching snippets
    pub snippets: Vec<String>,
    /// Why this result matched
    pub match_reason: String,
    /// Suggested follow-up queries
    pub follow_ups: Vec<String>,
}
108
/// Search analytics.
///
/// A snapshot is returned by `SearchRefinementAgent::get_analytics`. The
/// counters are maintained by `record_search` and `refine_query`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SearchAnalytics {
    /// Total searches (incremented once per `record_search` call)
    pub total_searches: u64,
    /// Successful searches (found results)
    pub successful_searches: u64,
    /// Refinements suggested (sum of suggestions returned by `refine_query`)
    pub refinements_suggested: u64,
    /// Refinements accepted (searches recorded with at least one refinement)
    pub refinements_accepted: u64,
    /// Average result relevance.
    /// NOTE(review): not updated by any code visible in this file.
    pub avg_relevance: f64,
    /// Common query patterns.
    /// NOTE(review): not updated by any code visible in this file.
    pub common_patterns: HashMap<String, u32>,
}
125
/// Context-aware search agent state.
///
/// Held by `SearchRefinementAgent` behind an `Arc<RwLock<..>>`.
pub struct SearchAgentState {
    /// Recent search history (trimmed to the most recent 1000 entries by
    /// `record_search`)
    search_history: Vec<SearchHistoryEntry>,
    /// Analytics
    analytics: SearchAnalytics,
    /// Query patterns learned
    patterns: Vec<QueryPattern>,
    /// Session context cache.
    /// NOTE(review): initialised empty and not otherwise used in the visible code.
    context_cache: HashMap<String, SearchContext>,
}
137
/// Search history entry, appended by `record_search` for every recorded search.
#[derive(Debug, Clone)]
struct SearchHistoryEntry {
    /// Query text as passed to `record_search`
    query: String,
    /// Time at which the search was recorded (`Utc::now()`)
    timestamp: DateTime<Utc>,
    /// Number of results the search returned
    result_count: u32,
    /// Refinements the caller reported as used for this search
    refinements_used: Vec<String>,
}
146
/// Learned query pattern, keyed by the string produced by `extract_pattern`.
#[derive(Debug, Clone)]
struct QueryPattern {
    /// Normalized pattern string
    pattern: String,
    /// Number of recorded searches that mapped to this pattern
    frequency: u32,
    /// Running mean of the result counts of those searches
    avg_results: f64,
    /// Refinements recorded when the pattern was first seen
    best_refinements: Vec<String>,
}
155
/// Context-aware search refinement agent.
///
/// Mutable state lives behind an `Arc<RwLock<..>>`, so the methods take
/// `&self` and acquire the lock asynchronously when they need it.
pub struct SearchRefinementAgent {
    /// Agent configuration (name, description and system prompt set in `new`)
    config: AgentConfig,
    /// Agent state
    state: Arc<RwLock<SearchAgentState>>,
}
163
164impl SearchRefinementAgent {
165    /// Create a new search refinement agent
166    pub fn new() -> Self {
167        let config = AgentConfig {
168            name: "search-refinement-agent".to_string(),
169            description: "Context-aware search query refinement".to_string(),
170            instruction: SEARCH_SYSTEM_PROMPT.to_string(),
171            ..Default::default()
172        };
173
174        let state = SearchAgentState {
175            search_history: Vec::new(),
176            analytics: SearchAnalytics::default(),
177            patterns: Vec::new(),
178            context_cache: HashMap::new(),
179        };
180
181        Self {
182            config,
183            state: Arc::new(RwLock::new(state)),
184        }
185    }
186
187    /// Analyze a query and suggest refinements
188    pub async fn refine_query(
189        &self,
190        query: &str,
191        context: Option<SearchContext>,
192    ) -> Vec<QueryRefinement> {
193        let mut refinements = Vec::new();
194        let query_lower = query.to_lowercase();
195
196        // 1. Check for common spelling corrections
197        let corrections = self.check_spelling(query);
198        for correction in corrections {
199            refinements.push(QueryRefinement {
200                query: correction.clone(),
201                refinement_type: RefinementType::Correction,
202                confidence: 0.9,
203                explanation: "Corrected potential typo".to_string(),
204                expected_improvement: "More accurate results".to_string(),
205            });
206        }
207
208        // 2. Suggest synonyms/related terms
209        let synonyms = self.find_synonyms(&query_lower);
210        for synonym in synonyms {
211            refinements.push(QueryRefinement {
212                query: format!("{} OR {}", query, synonym),
213                refinement_type: RefinementType::Synonyms,
214                confidence: 0.75,
215                explanation: format!("Added synonym: {}", synonym),
216                expected_improvement: "Broader coverage".to_string(),
217            });
218        }
219
220        // 3. Add contextual refinements
221        if let Some(ctx) = context {
222            // Use recent queries to suggest combinations
223            if !ctx.recent_queries.is_empty() {
224                let combined = format!("{} {}", query, ctx.recent_queries.last().unwrap());
225                refinements.push(QueryRefinement {
226                    query: combined,
227                    refinement_type: RefinementType::Contextual,
228                    confidence: 0.7,
229                    explanation: "Combined with recent search".to_string(),
230                    expected_improvement: "More relevant to your current focus".to_string(),
231                });
232            }
233
234            // Add provider filter if context suggests it
235            if ctx.providers.len() == 1 {
236                refinements.push(QueryRefinement {
237                    query: format!("{} provider:{}", query, ctx.providers[0]),
238                    refinement_type: RefinementType::ProviderFilter,
239                    confidence: 0.8,
240                    explanation: format!("Filtered to {} sessions", ctx.providers[0]),
241                    expected_improvement: "Focused on your active provider".to_string(),
242                });
243            }
244
245            // Add time filter for recency
246            refinements.push(QueryRefinement {
247                query: format!("{} after:7days", query),
248                refinement_type: RefinementType::Temporal,
249                confidence: 0.6,
250                explanation: "Limited to last 7 days".to_string(),
251                expected_improvement: "Recent and relevant results".to_string(),
252            });
253        }
254
255        // 4. Suggest specificity improvements
256        if query.split_whitespace().count() < 3 {
257            let specific_suggestions = self.suggest_specific_terms(&query_lower).await;
258            for suggestion in specific_suggestions {
259                refinements.push(QueryRefinement {
260                    query: format!("{} {}", query, suggestion),
261                    refinement_type: RefinementType::Specificity,
262                    confidence: 0.65,
263                    explanation: format!("Added specific term: {}", suggestion),
264                    expected_improvement: "More targeted results".to_string(),
265                });
266            }
267        }
268
269        // 5. Semantic expansion for technical queries
270        if self.is_technical_query(&query_lower) {
271            refinements.push(QueryRefinement {
272                query: query.to_string(),
273                refinement_type: RefinementType::SemanticExpansion,
274                confidence: 0.85,
275                explanation: "Use semantic search for technical content".to_string(),
276                expected_improvement: "Find conceptually related discussions".to_string(),
277            });
278        }
279
280        // Sort by confidence
281        refinements.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
282
283        // Update analytics
284        {
285            let mut state = self.state.write().await;
286            state.analytics.refinements_suggested += refinements.len() as u64;
287        }
288
289        refinements
290    }
291
292    /// Record a search for learning
293    pub async fn record_search(
294        &self,
295        query: &str,
296        result_count: u32,
297        refinements_used: Vec<String>,
298    ) {
299        let mut state = self.state.write().await;
300
301        state.search_history.push(SearchHistoryEntry {
302            query: query.to_string(),
303            timestamp: Utc::now(),
304            result_count,
305            refinements_used: refinements_used.clone(),
306        });
307
308        // Keep only last 1000 searches
309        let history_len = state.search_history.len();
310        if history_len > 1000 {
311            state.search_history.drain(0..history_len - 1000);
312        }
313
314        // Update analytics
315        state.analytics.total_searches += 1;
316        if result_count > 0 {
317            state.analytics.successful_searches += 1;
318        }
319        if !refinements_used.is_empty() {
320            state.analytics.refinements_accepted += 1;
321        }
322
323        // Update patterns
324        let pattern = self.extract_pattern(query);
325        if let Some(existing) = state.patterns.iter_mut().find(|p| p.pattern == pattern) {
326            existing.frequency += 1;
327            existing.avg_results = (existing.avg_results * (existing.frequency - 1) as f64
328                + result_count as f64)
329                / existing.frequency as f64;
330        } else {
331            state.patterns.push(QueryPattern {
332                pattern,
333                frequency: 1,
334                avg_results: result_count as f64,
335                best_refinements: refinements_used,
336            });
337        }
338    }
339
340    /// Get search analytics
341    pub async fn get_analytics(&self) -> SearchAnalytics {
342        let state = self.state.read().await;
343        state.analytics.clone()
344    }
345
346    /// Suggest related searches based on a result
347    pub async fn suggest_follow_ups(&self, _session_id: &str, query: &str) -> Vec<String> {
348        let mut suggestions = Vec::new();
349
350        // Based on common follow-up patterns
351        suggestions.push(format!("{} example", query));
352        suggestions.push(format!("{} solution", query));
353        suggestions.push(format!("related to {}", query));
354
355        // Could use LLM to generate more contextual suggestions
356        // based on session content
357
358        suggestions
359    }
360
361    /// Check for common spelling mistakes
362    fn check_spelling(&self, query: &str) -> Vec<String> {
363        let mut corrections = Vec::new();
364
365        // Common programming term corrections
366        let corrections_map: HashMap<&str, &str> = [
367            ("javascrip", "javascript"),
368            ("pytohn", "python"),
369            ("typescrip", "typescript"),
370            ("fucntion", "function"),
371            ("aync", "async"),
372            ("awiat", "await"),
373            ("improt", "import"),
374            ("exprot", "export"),
375            ("cosnt", "const"),
376            ("retrun", "return"),
377        ]
378        .iter()
379        .cloned()
380        .collect();
381
382        let _words: Vec<&str> = query.split_whitespace().collect();
383        for (typo, correct) in &corrections_map {
384            if query.to_lowercase().contains(typo) {
385                let corrected = query.to_lowercase().replace(typo, correct);
386                corrections.push(corrected);
387            }
388        }
389
390        corrections
391    }
392
393    /// Find synonyms for terms
394    fn find_synonyms(&self, query: &str) -> Vec<String> {
395        let mut synonyms = Vec::new();
396
397        let synonym_map: HashMap<&str, Vec<&str>> = [
398            ("error", vec!["exception", "bug", "issue", "problem"]),
399            ("function", vec!["method", "procedure", "routine"]),
400            ("variable", vec!["var", "const", "let", "parameter"]),
401            ("create", vec!["make", "generate", "build", "new"]),
402            ("delete", vec!["remove", "destroy", "drop"]),
403            ("find", vec!["search", "locate", "query", "get"]),
404            ("update", vec!["modify", "change", "edit", "patch"]),
405            ("api", vec!["endpoint", "route", "service"]),
406            ("database", vec!["db", "storage", "repository"]),
407        ]
408        .iter()
409        .cloned()
410        .collect();
411
412        for (term, syns) in &synonym_map {
413            if query.contains(term) {
414                for syn in syns {
415                    synonyms.push(syn.to_string());
416                }
417            }
418        }
419
420        synonyms.truncate(3); // Limit suggestions
421        synonyms
422    }
423
424    /// Suggest more specific terms
425    async fn suggest_specific_terms(&self, query: &str) -> Vec<String> {
426        let mut suggestions = Vec::new();
427
428        // Common specificity additions based on query content
429        if query.contains("error") || query.contains("bug") {
430            suggestions.push("fix".to_string());
431            suggestions.push("solution".to_string());
432        }
433        if query.contains("how") {
434            suggestions.push("step-by-step".to_string());
435            suggestions.push("example".to_string());
436        }
437        if query.contains("best") {
438            suggestions.push("practice".to_string());
439            suggestions.push("approach".to_string());
440        }
441
442        suggestions.truncate(2);
443        suggestions
444    }
445
446    /// Check if query is technical
447    fn is_technical_query(&self, query: &str) -> bool {
448        let technical_terms = [
449            "function",
450            "class",
451            "method",
452            "api",
453            "error",
454            "bug",
455            "code",
456            "implement",
457            "debug",
458            "async",
459            "await",
460            "promise",
461            "callback",
462            "component",
463            "module",
464            "import",
465            "export",
466            "typescript",
467            "javascript",
468            "python",
469            "rust",
470            "react",
471            "vue",
472            "angular",
473            "node",
474            "sql",
475        ];
476
477        technical_terms.iter().any(|term| query.contains(term))
478    }
479
480    /// Extract pattern from query for learning
481    fn extract_pattern(&self, query: &str) -> String {
482        // Normalize query to pattern
483        let words: Vec<&str> = query.split_whitespace().collect();
484        if words.len() <= 2 {
485            return query.to_lowercase();
486        }
487
488        // Keep structure but replace specific terms with placeholders
489        words
490            .iter()
491            .map(|w| if w.len() > 5 { "[TERM]" } else { *w })
492            .collect::<Vec<_>>()
493            .join(" ")
494    }
495}
496
497impl Default for SearchRefinementAgent {
498    fn default() -> Self {
499        Self::new()
500    }
501}
502
/// System prompt for the search refinement agent.
///
/// Used as the `instruction` of the `AgentConfig` built in
/// `SearchRefinementAgent::new`.
const SEARCH_SYSTEM_PROMPT: &str = r#"You are a context-aware search refinement agent for Chasm.

Your role is to help users find relevant chat sessions by:
1. Understanding the intent behind their search queries
2. Suggesting refinements that will improve results
3. Learning from search patterns to make better suggestions
4. Providing contextual suggestions based on recent activity

When refining a query, consider:
- Is the query too broad or too specific?
- Are there common synonyms or related terms?
- Does the user's recent activity suggest a focus area?
- Would time-based or provider-based filters help?

Always explain why a refinement might help.
"#;
520
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created agent has recorded no searches.
    #[tokio::test]
    async fn test_search_agent_creation() {
        let agent = SearchRefinementAgent::new();
        let analytics = agent.get_analytics().await;
        assert_eq!(analytics.total_searches, 0);
        assert_eq!(analytics.refinements_suggested, 0);
    }

    /// A technical query yields suggestions even without context, and the
    /// suggestions are ordered by descending confidence.
    #[tokio::test]
    async fn test_refine_query_basic() {
        let agent = SearchRefinementAgent::new();
        let refinements = agent.refine_query("python error", None).await;
        assert!(!refinements.is_empty());
        assert!(refinements
            .windows(2)
            .all(|pair| pair[0].confidence >= pair[1].confidence));
    }

    /// Context with a recent query and a single provider produces contextual,
    /// provider-filter and temporal suggestions.
    #[tokio::test]
    async fn test_refine_query_with_context() {
        let agent = SearchRefinementAgent::new();
        let context = SearchContext {
            recent_queries: vec!["async await".to_string()],
            recent_sessions: vec![],
            workspace_id: Some("test-workspace".to_string()),
            providers: vec!["copilot".to_string()],
            preferences: SearchPreferences::default(),
            time_range: None,
        };
        let refinements = agent.refine_query("function", Some(context)).await;

        // Assert each expected kind on its own: a disjunction with
        // `!refinements.is_empty()` would pass even if none were produced.
        let has_kind = |kind: RefinementType| refinements.iter().any(|r| r.refinement_type == kind);
        assert!(has_kind(RefinementType::Contextual));
        assert!(has_kind(RefinementType::ProviderFilter));
        assert!(has_kind(RefinementType::Temporal));
    }

    /// A known typo produces a correction containing the fixed word.
    #[tokio::test]
    async fn test_spelling_correction() {
        let agent = SearchRefinementAgent::new();
        let refinements = agent.refine_query("pytohn function", None).await;

        let has_correction = refinements.iter().any(|r| {
            r.refinement_type == RefinementType::Correction && r.query == "python function"
        });
        assert!(has_correction);
    }

    /// Recording a search with results and no refinements updates the totals
    /// but not the accepted-refinements counter.
    #[tokio::test]
    async fn test_record_search() {
        let agent = SearchRefinementAgent::new();
        agent.record_search("test query", 10, vec![]).await;

        let analytics = agent.get_analytics().await;
        assert_eq!(analytics.total_searches, 1);
        assert_eq!(analytics.successful_searches, 1);
        assert_eq!(analytics.refinements_accepted, 0);
    }
}