use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::agency::{Agent, AgentBuilder, AgentConfig};
/// Caller-supplied context that lets the agent tailor refinements to recent activity.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SearchContext {
    /// Queries run recently; the last element is treated as the most recent one.
    pub recent_queries: Vec<String>,
    /// Identifiers of recently viewed sessions.
    pub recent_sessions: Vec<String>,
    /// Optional workspace identifier.
    pub workspace_id: Option<String>,
    /// Providers in scope; when exactly one is listed, a provider filter is suggested.
    pub providers: Vec<String>,
    /// User search preferences.
    pub preferences: SearchPreferences,
    /// Optional time window for the search.
    pub time_range: Option<TimeRange>,
}
/// A time window with optional bounds; either side may be left open (`None`).
///
/// NOTE(review): whether the bounds are inclusive is not established in this file.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TimeRange {
    /// Lower bound, or `None` for "no lower bound".
    pub start: Option<DateTime<Utc>>,
    /// Upper bound, or `None` for "no upper bound".
    pub end: Option<DateTime<Utc>>,
}
/// User-level search preferences.
///
/// The derived `Default` sets every flag to `false` and `result_limit` to `0`.
/// NOTE(review): none of these fields are interpreted in this file; presumably the
/// search backend consumes them — confirm.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct SearchPreferences {
    /// Requested maximum number of results.
    pub result_limit: u32,
    /// Whether semantic search is enabled.
    pub semantic_enabled: bool,
    /// Whether archived sessions are included.
    pub include_archived: bool,
    /// Whether matches should be highlighted.
    pub highlight_matches: bool,
    /// Whether results should be grouped per session.
    pub group_by_session: bool,
}
/// One suggested rewrite of a search query, with a human-readable rationale.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct QueryRefinement {
    /// The rewritten query text.
    pub query: String,
    /// Which heuristic produced this suggestion.
    pub refinement_type: RefinementType,
    /// Heuristic confidence; suggestions are ordered by this value, highest first.
    pub confidence: f64,
    /// Why the refinement was suggested.
    pub explanation: String,
    /// What the user can expect from applying it.
    pub expected_improvement: String,
}
/// The kind of change a [`QueryRefinement`] applies to the original query.
///
/// Serialized in `snake_case` (e.g. `ProviderFilter` <-> `"provider_filter"`). The enum is
/// fieldless, so it is `Copy` and can be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RefinementType {
    /// Appends more specific terms to a short query.
    Specificity,
    /// Widens an overly narrow query.
    /// NOTE(review): not emitted by `refine_query` in this file.
    Broadening,
    /// Fixes a known typo.
    Correction,
    /// ORs a synonym onto the query.
    Synonyms,
    /// Combines the query with the user's most recent previous query.
    Contextual,
    /// Restricts the query to a recent time window (`after:7days`).
    Temporal,
    /// Restricts the query to a single provider (`provider:<name>`).
    ProviderFilter,
    /// Keeps the query text but recommends semantic search for it.
    SemanticExpansion,
}
/// A search hit annotated with presentation details.
///
/// NOTE(review): this type is not constructed anywhere in this file; field meanings
/// below are inferred from names — confirm against the producer.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EnrichedSearchResult {
    /// Identifier of the matching session.
    pub session_id: String,
    /// Session title.
    pub title: String,
    /// Relevance score (scale not established here).
    pub relevance: f64,
    /// Text excerpts from the session.
    pub snippets: Vec<String>,
    /// Presumably a human-readable reason for the match.
    pub match_reason: String,
    /// Presumably suggested follow-up queries.
    pub follow_ups: Vec<String>,
}
/// Aggregate counters describing how the agent has been used.
///
/// The derived `Default` starts every counter at zero with an empty pattern map.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct SearchAnalytics {
    /// Number of recorded searches.
    pub total_searches: u64,
    /// Recorded searches that returned at least one result.
    pub successful_searches: u64,
    /// Total number of refinements handed out by `refine_query`.
    pub refinements_suggested: u64,
    /// Recorded searches that used at least one refinement.
    pub refinements_accepted: u64,
    /// Average relevance.
    /// NOTE(review): nothing in this file updates this value — confirm who does.
    pub avg_relevance: f64,
    /// Presumably query pattern -> occurrence count; confirm which code maintains it.
    pub common_patterns: HashMap<String, u32>,
}
/// Mutable state shared by the methods of [`SearchRefinementAgent`] behind its `RwLock`.
pub struct SearchAgentState {
    /// Recorded searches, oldest first; `record_search` trims it to a bounded size.
    search_history: Vec<SearchHistoryEntry>,
    /// Counters returned (cloned) by `get_analytics`.
    analytics: SearchAnalytics,
    /// One entry per distinct query pattern seen by `record_search`.
    patterns: Vec<QueryPattern>,
    /// NOTE(review): only initialized (empty) in this file; never read or written.
    context_cache: HashMap<String, SearchContext>,
}
/// A single recorded search.
#[derive(Clone, Debug)]
struct SearchHistoryEntry {
    /// The query text as recorded.
    query: String,
    /// When the search was recorded (`Utc::now()` at record time).
    timestamp: DateTime<Utc>,
    /// Number of results the search returned.
    result_count: u32,
    /// Refinement strings the user applied for this search.
    refinements_used: Vec<String>,
}
/// Usage statistics for one normalized query shape.
#[derive(Clone, Debug)]
struct QueryPattern {
    /// Normalized pattern text (long words may be masked as `[TERM]`).
    pattern: String,
    /// How many recorded searches mapped to this pattern.
    frequency: u32,
    /// Running mean of the result counts of those searches.
    avg_results: f64,
    /// Refinements recorded with the first search of this pattern.
    best_refinements: Vec<String>,
}
/// Rule-based agent that suggests search-query refinements and tracks search analytics.
///
/// State lives behind `Arc<RwLock<_>>` so that `&self` async methods can update it.
pub struct SearchRefinementAgent {
    /// Agent metadata and instruction.
    config: AgentConfig,
    /// Shared, lock-protected history, analytics and pattern statistics.
    state: Arc<RwLock<SearchAgentState>>,
}
impl SearchRefinementAgent {
    /// Maximum number of entries retained in the search history; the oldest are dropped
    /// first.
    const MAX_HISTORY: usize = 1000;

    /// Creates an agent with an empty search history and zeroed analytics.
    ///
    /// The agent is named `search-refinement-agent` and uses [`SEARCH_SYSTEM_PROMPT`] as
    /// its instruction; every other [`AgentConfig`] field comes from its `Default`.
    pub fn new() -> Self {
        let config = AgentConfig {
            name: "search-refinement-agent".to_string(),
            description: "Context-aware search query refinement".to_string(),
            instruction: SEARCH_SYSTEM_PROMPT.to_string(),
            ..Default::default()
        };
        let state = SearchAgentState {
            search_history: Vec::new(),
            analytics: SearchAnalytics::default(),
            patterns: Vec::new(),
            context_cache: HashMap::new(),
        };
        Self {
            config,
            state: Arc::new(RwLock::new(state)),
        }
    }

    /// Suggests refinements for `query`, sorted by descending confidence.
    ///
    /// Candidates are generated from:
    /// - spelling fixes for known typos ([`RefinementType::Correction`]);
    /// - up to three synonyms OR-ed onto the query ([`RefinementType::Synonyms`]);
    /// - when `context` is given: the most recent previous query, a filter for a single
    ///   active provider, and a 7-day window unless the context already carries a
    ///   bounded time range;
    /// - extra terms for queries of fewer than three words
    ///   ([`RefinementType::Specificity`]);
    /// - a semantic-search hint for technical queries
    ///   ([`RefinementType::SemanticExpansion`]).
    ///
    /// Surrounding whitespace is ignored and a blank query yields no suggestions. The
    /// number of suggestions returned is added to `SearchAnalytics::refinements_suggested`.
    pub async fn refine_query(
        &self,
        query: &str,
        context: Option<SearchContext>,
    ) -> Vec<QueryRefinement> {
        // Trim so appended terms and filters are separated by exactly one space.
        let query = query.trim();
        if query.is_empty() {
            return Vec::new();
        }
        let query_lower = query.to_lowercase();
        let mut refinements = Vec::new();

        for correction in self.check_spelling(query) {
            refinements.push(QueryRefinement {
                query: correction,
                refinement_type: RefinementType::Correction,
                confidence: 0.9,
                explanation: "Corrected potential typo".to_string(),
                expected_improvement: "More accurate results".to_string(),
            });
        }

        for synonym in self.find_synonyms(&query_lower) {
            refinements.push(QueryRefinement {
                query: format!("{} OR {}", query, synonym),
                refinement_type: RefinementType::Synonyms,
                confidence: 0.75,
                explanation: format!("Added synonym: {}", synonym),
                expected_improvement: "Broader coverage".to_string(),
            });
        }

        if let Some(ctx) = context {
            if let Some(previous) = ctx.recent_queries.last().map(|q| q.trim()) {
                // Combining with a blank or identical previous query adds nothing.
                if !previous.is_empty() && previous.to_lowercase() != query_lower {
                    refinements.push(QueryRefinement {
                        query: format!("{} {}", query, previous),
                        refinement_type: RefinementType::Contextual,
                        confidence: 0.7,
                        explanation: "Combined with recent search".to_string(),
                        expected_improvement: "More relevant to your current focus".to_string(),
                    });
                }
            }

            if let [provider] = ctx.providers.as_slice() {
                refinements.push(QueryRefinement {
                    query: format!("{} provider:{}", query, provider),
                    refinement_type: RefinementType::ProviderFilter,
                    confidence: 0.8,
                    explanation: format!("Filtered to {} sessions", provider),
                    expected_improvement: "Focused on your active provider".to_string(),
                });
            }

            // A bounded time range in the context already constrains the search; a
            // competing "last 7 days" filter would conflict with it.
            let has_time_range = matches!(
                &ctx.time_range,
                Some(range) if range.start.is_some() || range.end.is_some()
            );
            if !has_time_range {
                refinements.push(QueryRefinement {
                    query: format!("{} after:7days", query),
                    refinement_type: RefinementType::Temporal,
                    confidence: 0.6,
                    explanation: "Limited to last 7 days".to_string(),
                    expected_improvement: "Recent and relevant results".to_string(),
                });
            }
        }

        if query.split_whitespace().count() < 3 {
            for suggestion in self.suggest_specific_terms(&query_lower).await {
                refinements.push(QueryRefinement {
                    query: format!("{} {}", query, suggestion),
                    refinement_type: RefinementType::Specificity,
                    confidence: 0.65,
                    explanation: format!("Added specific term: {}", suggestion),
                    expected_improvement: "More targeted results".to_string(),
                });
            }
        }

        if self.is_technical_query(&query_lower) {
            refinements.push(QueryRefinement {
                query: query.to_string(),
                refinement_type: RefinementType::SemanticExpansion,
                confidence: 0.85,
                explanation: "Use semantic search for technical content".to_string(),
                expected_improvement: "Find conceptually related discussions".to_string(),
            });
        }

        // `total_cmp` is a total order (no panic even on NaN). `sort_by` is stable, so
        // equal-confidence suggestions keep their generation order.
        refinements.sort_by(|a, b| b.confidence.total_cmp(&a.confidence));
        {
            let mut state = self.state.write().await;
            state.analytics.refinements_suggested += refinements.len() as u64;
        }
        refinements
    }

    /// Records a completed search and updates analytics and per-pattern statistics.
    ///
    /// The history keeps at most [`Self::MAX_HISTORY`] entries (oldest dropped first).
    /// A search counts as successful when `result_count > 0`, and as having accepted a
    /// refinement when `refinements_used` is non-empty.
    pub async fn record_search(
        &self,
        query: &str,
        result_count: u32,
        refinements_used: Vec<String>,
    ) {
        // Pattern extraction only depends on `query`; do it before taking the lock.
        let pattern = self.extract_pattern(query);
        let mut state = self.state.write().await;

        state.search_history.push(SearchHistoryEntry {
            query: query.to_string(),
            timestamp: Utc::now(),
            result_count,
            refinements_used: refinements_used.clone(),
        });
        // Entries are appended at the back, so the oldest ones sit at the front.
        let excess = state
            .search_history
            .len()
            .saturating_sub(Self::MAX_HISTORY);
        if excess > 0 {
            state.search_history.drain(..excess);
        }

        state.analytics.total_searches += 1;
        if result_count > 0 {
            state.analytics.successful_searches += 1;
        }
        if !refinements_used.is_empty() {
            state.analytics.refinements_accepted += 1;
        }
        // Expose pattern frequencies through the public analytics snapshot.
        *state
            .analytics
            .common_patterns
            .entry(pattern.clone())
            .or_insert(0) += 1;

        if let Some(existing) = state.patterns.iter_mut().find(|p| p.pattern == pattern) {
            existing.frequency += 1;
            // Running mean of result counts over every occurrence of this pattern.
            existing.avg_results = (existing.avg_results * f64::from(existing.frequency - 1)
                + f64::from(result_count))
                / f64::from(existing.frequency);
        } else {
            state.patterns.push(QueryPattern {
                pattern,
                frequency: 1,
                avg_results: f64::from(result_count),
                best_refinements: refinements_used,
            });
        }
    }

    /// Returns a snapshot (clone) of the current analytics.
    pub async fn get_analytics(&self) -> SearchAnalytics {
        let state = self.state.read().await;
        state.analytics.clone()
    }

    /// Returns three templated follow-up queries for `query`.
    ///
    /// `_session_id` is currently unused.
    pub async fn suggest_follow_ups(&self, _session_id: &str, query: &str) -> Vec<String> {
        vec![
            format!("{} example", query),
            format!("{} solution", query),
            format!("related to {}", query),
        ]
    }

    /// Returns a corrected version of `query` if it contains known typos, else nothing.
    ///
    /// Matching is per whitespace-separated word and case-insensitive. A correctly
    /// spelled word that merely contains a typo key is therefore left alone (the old
    /// substring match turned "javascript" into "javascriptt"). All typos are fixed
    /// in a single suggestion; other words keep their original text, and words are
    /// re-joined with single spaces.
    fn check_spelling(&self, query: &str) -> Vec<String> {
        const TYPOS: &[(&str, &str)] = &[
            ("javascrip", "javascript"),
            ("pytohn", "python"),
            ("typescrip", "typescript"),
            ("fucntion", "function"),
            ("aync", "async"),
            ("awiat", "await"),
            ("improt", "import"),
            ("exprot", "export"),
            ("cosnt", "const"),
            ("retrun", "return"),
        ];
        let mut corrected_any = false;
        let mut words: Vec<&str> = Vec::new();
        for word in query.split_whitespace() {
            let lower = word.to_lowercase();
            match TYPOS.iter().find(|&&(typo, _)| typo == lower) {
                Some(&(_, fix)) => {
                    corrected_any = true;
                    words.push(fix);
                }
                None => words.push(word),
            }
        }
        if corrected_any {
            vec![words.join(" ")]
        } else {
            Vec::new()
        }
    }

    /// Returns up to three synonyms for terms mentioned in `query` (expected lowercase).
    ///
    /// The table is scanned in a fixed order, so results are deterministic. Synonyms the
    /// query already contains, or that were already collected, are skipped.
    fn find_synonyms(&self, query: &str) -> Vec<String> {
        const MAX_SYNONYMS: usize = 3;
        const SYNONYMS: &[(&str, &[&str])] = &[
            ("error", &["exception", "bug", "issue", "problem"]),
            ("function", &["method", "procedure", "routine"]),
            ("variable", &["var", "const", "let", "parameter"]),
            ("create", &["make", "generate", "build", "new"]),
            ("delete", &["remove", "destroy", "drop"]),
            ("find", &["search", "locate", "query", "get"]),
            ("update", &["modify", "change", "edit", "patch"]),
            ("api", &["endpoint", "route", "service"]),
            ("database", &["db", "storage", "repository"]),
        ];
        let mut synonyms: Vec<String> = Vec::new();
        for &(term, alternatives) in SYNONYMS {
            if !Self::mentions_term(query, term) {
                continue;
            }
            for &alternative in alternatives {
                let redundant = Self::contains_word(query, alternative)
                    || synonyms.iter().any(|s| s.as_str() == alternative);
                if redundant {
                    continue;
                }
                synonyms.push(alternative.to_string());
                if synonyms.len() == MAX_SYNONYMS {
                    return synonyms;
                }
            }
        }
        synonyms
    }

    /// Suggests at most two extra terms for a short query (expected lowercase).
    ///
    /// Rules are checked in order (error/bug, how, best). Terms the query already
    /// contains are skipped.
    async fn suggest_specific_terms(&self, query: &str) -> Vec<String> {
        const MAX_SUGGESTIONS: usize = 2;
        const RULES: &[(&[&str], &[&str])] = &[
            (&["error", "bug"], &["fix", "solution"]),
            (&["how"], &["step-by-step", "example"]),
            (&["best"], &["practice", "approach"]),
        ];
        let mut suggestions: Vec<String> = Vec::new();
        for &(triggers, terms) in RULES {
            if !triggers
                .iter()
                .any(|trigger| Self::mentions_term(query, trigger))
            {
                continue;
            }
            suggestions.extend(
                terms
                    .iter()
                    .copied()
                    .filter(|term| !Self::contains_word(query, term))
                    .map(str::to_string),
            );
        }
        suggestions.truncate(MAX_SUGGESTIONS);
        suggestions
    }

    /// Returns `true` when a word of `query` (expected lowercase) starts with a known
    /// technical term.
    fn is_technical_query(&self, query: &str) -> bool {
        const TECHNICAL_TERMS: &[&str] = &[
            "function",
            "class",
            "method",
            "api",
            "error",
            "bug",
            "code",
            "implement",
            "debug",
            "async",
            "await",
            "promise",
            "callback",
            "component",
            "module",
            "import",
            "export",
            "typescript",
            "javascript",
            "python",
            "rust",
            "react",
            "vue",
            "angular",
            "node",
            "sql",
            // Database names in which "sql" is not a word prefix.
            "mysql",
            "postgres",
            "nosql",
        ];
        TECHNICAL_TERMS
            .iter()
            .any(|term| Self::mentions_term(query, term))
    }

    /// Normalizes `query` into a pattern used to group similar searches.
    ///
    /// Words are lowercased and re-joined with single spaces. For queries of more than
    /// two words, words longer than five characters are replaced with `[TERM]`.
    fn extract_pattern(&self, query: &str) -> String {
        let words: Vec<String> = query.split_whitespace().map(str::to_lowercase).collect();
        if words.len() <= 2 {
            return words.join(" ");
        }
        words
            .iter()
            .map(|word| {
                if word.chars().count() > 5 {
                    "[TERM]"
                } else {
                    word.as_str()
                }
            })
            .collect::<Vec<_>>()
            .join(" ")
    }

    /// Splits `text` on whitespace and strips leading/trailing non-alphanumeric
    /// characters from each word, dropping words that become empty.
    fn words<'a>(text: &'a str) -> impl Iterator<Item = &'a str> + 'a {
        text.split_whitespace()
            .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|word| !word.is_empty())
    }

    /// Returns `true` if `word` appears verbatim as a word of `query`.
    fn contains_word(query: &str, word: &str) -> bool {
        Self::words(query).any(|w| w == word)
    }

    /// Returns `true` if some word of `query` starts with `term`.
    ///
    /// Prefix matching accepts inflections ("errors", "apis") while rejecting unrelated
    /// words that merely contain the term ("terror", "rapid", "show").
    fn mentions_term(query: &str, term: &str) -> bool {
        Self::words(query).any(|w| w.starts_with(term))
    }
}
impl Default for SearchRefinementAgent {
fn default() -> Self {
Self::new()
}
}
/// System instruction for the search refinement agent.
///
/// Used to populate `AgentConfig::instruction` in `SearchRefinementAgent::new`; the
/// rule-based refinement heuristics in this file do not consult it.
const SEARCH_SYSTEM_PROMPT: &str = r#"You are a context-aware search refinement agent for Chasm.
Your role is to help users find relevant chat sessions by:
1. Understanding the intent behind their search queries
2. Suggesting refinements that will improve results
3. Learning from search patterns to make better suggestions
4. Providing contextual suggestions based on recent activity
When refining a query, consider:
- Is the query too broad or too specific?
- Are there common synonyms or related terms?
- Does the user's recent activity suggest a focus area?
- Would time-based or provider-based filters help?
Always explain why a refinement might help.
"#;
#[cfg(test)]
mod tests {
    use super::*;

    /// A context with one recent query, one provider and no time range.
    fn sample_context() -> SearchContext {
        SearchContext {
            recent_queries: vec!["async await".to_string()],
            recent_sessions: vec![],
            workspace_id: Some("test-workspace".to_string()),
            providers: vec!["copilot".to_string()],
            preferences: SearchPreferences::default(),
            time_range: None,
        }
    }

    #[tokio::test]
    async fn test_search_agent_creation() {
        let agent = SearchRefinementAgent::new();
        let analytics = agent.get_analytics().await;
        assert_eq!(analytics.total_searches, 0);
    }

    #[tokio::test]
    async fn test_refine_query_basic() {
        let agent = SearchRefinementAgent::new();
        let refinements = agent.refine_query("python error", None).await;
        assert!(!refinements.is_empty());
    }

    #[tokio::test]
    async fn test_refine_query_with_context() {
        let agent = SearchRefinementAgent::new();
        let refinements = agent
            .refine_query("function", Some(sample_context()))
            .await;
        let has = |kind: RefinementType| refinements.iter().any(|r| r.refinement_type == kind);
        // Each context-driven heuristic must fire for this context; a non-empty check
        // alone would pass even if none of them did.
        assert!(has(RefinementType::Contextual));
        assert!(has(RefinementType::ProviderFilter));
        assert!(has(RefinementType::Temporal));
    }

    #[tokio::test]
    async fn test_spelling_correction() {
        let agent = SearchRefinementAgent::new();
        let refinements = agent.refine_query("pytohn function", None).await;
        let has_correction = refinements.iter().any(|r| {
            r.refinement_type == RefinementType::Correction && r.query == "python function"
        });
        assert!(has_correction);
    }

    #[tokio::test]
    async fn test_refinements_sorted_and_counted() {
        let agent = SearchRefinementAgent::new();
        let refinements = agent.refine_query("python error", None).await;
        assert!(refinements
            .windows(2)
            .all(|pair| pair[0].confidence >= pair[1].confidence));
        let analytics = agent.get_analytics().await;
        assert_eq!(analytics.refinements_suggested, refinements.len() as u64);
    }

    #[tokio::test]
    async fn test_record_search() {
        let agent = SearchRefinementAgent::new();
        agent.record_search("test query", 10, vec![]).await;
        let analytics = agent.get_analytics().await;
        assert_eq!(analytics.total_searches, 1);
        assert_eq!(analytics.successful_searches, 1);
        assert_eq!(analytics.refinements_accepted, 0);
    }

    #[tokio::test]
    async fn test_record_failed_search_with_refinement() {
        let agent = SearchRefinementAgent::new();
        agent
            .record_search("missing thing", 0, vec!["missing thing OR item".to_string()])
            .await;
        let analytics = agent.get_analytics().await;
        assert_eq!(analytics.total_searches, 1);
        assert_eq!(analytics.successful_searches, 0);
        assert_eq!(analytics.refinements_accepted, 1);
    }

    #[tokio::test]
    async fn test_record_search_running_average() {
        let agent = SearchRefinementAgent::new();
        agent.record_search("rust async error", 10, vec![]).await;
        agent.record_search("rust async error", 20, vec![]).await;
        let state = agent.state.read().await;
        assert_eq!(state.patterns.len(), 1);
        assert_eq!(state.patterns[0].frequency, 2);
        assert!((state.patterns[0].avg_results - 15.0).abs() < 1e-9);
    }
}