1use anyhow::Result;
7use serde_json::json;
8
9use crate::rag::{llm::LlmClient, SmartSearchConfig};
10
11#[derive(Debug, Clone)]
13pub struct EnhancedQuery {
14 pub original: String,
15 pub variations: Vec<QueryVariation>,
16 pub detected_intent: QueryIntent,
17 #[allow(dead_code)]
18 pub suggested_terms: Vec<String>,
19}
20
21#[derive(Debug, Clone)]
23pub struct QueryVariation {
24 pub query: String,
25 pub strategy: SearchStrategy,
26 pub weight: f32,
27}
28
/// How a [`QueryVariation`] should be matched by the search backend.
#[derive(Clone, Debug)]
pub enum SearchStrategy {
    /// Meaning-based matching; used for LLM and rephrased variations.
    Semantic,
    /// Keyword matching; used for generic fallback variations.
    Keyword,
    /// Fuzzy matching.
    // NOTE(review): not constructed by the enhancer as written, hence the allow.
    #[allow(dead_code)]
    Fuzzy,
    /// Code-oriented matching; used for language/component-prefixed variations.
    Code,
    /// Combined matching; used for the original, unmodified query.
    Mixed,
}
39
/// Coarse classification of what a search query is looking for.
#[derive(Clone, Debug)]
pub enum QueryIntent {
    /// The query targets code; details are filled in when they can be detected.
    CodeSearch {
        /// Detected programming language name, e.g. `"rust"`.
        language: Option<String>,
        /// Detected component kind, e.g. `"function"` or `"class"`.
        component_type: Option<String>,
    },
    /// Guides, tutorials, how-tos and examples.
    Documentation,
    /// Settings, environment and configuration.
    Configuration,
    /// General technical concepts (the detector's default classification).
    TechnicalConcept,
    /// Errors, bugs and troubleshooting.
    Debugging,
    /// Unclassified query.
    // NOTE(review): not produced by the keyword detector as written, hence the allow.
    #[allow(dead_code)]
    Unknown,
}
54
55pub struct QueryEnhancer {
57 llm_client: Option<LlmClient>,
58 config: SmartSearchConfig,
59}
60
61impl QueryEnhancer {
62 pub fn new(llm_client: Option<LlmClient>, config: SmartSearchConfig) -> Self {
64 Self { llm_client, config }
65 }
66
67 pub async fn enhance_query(&self, query: &str) -> Result<EnhancedQuery> {
69 log::debug!("Enhancing query: '{}'", query);
70
71 let detected_intent = self.detect_query_intent(query).await?;
73
74 let mut variations = Vec::new();
75 let mut suggested_terms = Vec::new();
76
77 if let Some(ref llm_client) = self.llm_client {
79 if self.config.enable_query_enhancement {
80 match self
81 .enhance_with_llm(query, &detected_intent, llm_client)
82 .await
83 {
84 Ok((llm_variations, llm_terms)) => {
85 variations.extend(llm_variations);
86 suggested_terms.extend(llm_terms);
87 log::debug!(
88 "LLM enhancement succeeded with {} variations",
89 variations.len()
90 );
91 }
92 Err(e) => {
93 log::warn!("LLM enhancement failed, using fallback: {}", e);
94 }
95 }
96 }
97 }
98
99 let fallback_variations = self.enhance_with_fallback(query, &detected_intent);
101 variations.extend(fallback_variations);
102
103 variations.insert(
105 0,
106 QueryVariation {
107 query: query.to_string(),
108 strategy: SearchStrategy::Mixed,
109 weight: 1.0,
110 },
111 );
112
113 variations.truncate(self.config.max_query_variations.max(1));
115
116 Ok(EnhancedQuery {
117 original: query.to_string(),
118 variations,
119 detected_intent,
120 suggested_terms,
121 })
122 }
123
124 async fn detect_query_intent(&self, query: &str) -> Result<QueryIntent> {
126 let query_lower = query.to_lowercase();
127
128 if self.is_code_query(&query_lower) {
130 let language = self.detect_programming_language(&query_lower);
131 let component_type = self.detect_component_type(&query_lower);
132
133 return Ok(QueryIntent::CodeSearch {
134 language,
135 component_type,
136 });
137 }
138
139 if query_lower.contains("config")
141 || query_lower.contains("settings")
142 || query_lower.contains("environment")
143 {
144 return Ok(QueryIntent::Configuration);
145 }
146
147 if query_lower.contains("error")
149 || query_lower.contains("bug")
150 || query_lower.contains("debug")
151 || query_lower.contains("issue")
152 || query_lower.contains("problem")
153 {
154 return Ok(QueryIntent::Debugging);
155 }
156
157 if query_lower.contains("how to")
159 || query_lower.contains("guide")
160 || query_lower.contains("tutorial")
161 || query_lower.contains("example")
162 {
163 return Ok(QueryIntent::Documentation);
164 }
165
166 Ok(QueryIntent::TechnicalConcept)
167 }
168
169 fn is_code_query(&self, query: &str) -> bool {
171 let code_indicators = [
172 "function",
173 "method",
174 "class",
175 "struct",
176 "interface",
177 "variable",
178 "implementation",
179 "where is",
180 "how does",
181 "used",
182 "called",
183 "middleware",
184 "authentication",
185 "validation",
186 "security",
187 "database",
188 "connection",
189 "handler",
190 "controller",
191 "service",
192 "component",
193 "module",
194 "library",
195 "package",
196 "import",
197 ];
198
199 code_indicators
200 .iter()
201 .any(|&indicator| query.contains(indicator))
202 }
203
204 fn detect_programming_language(&self, query: &str) -> Option<String> {
206 let language_keywords = [
207 (
208 "rust",
209 vec!["fn", "impl", "struct", "trait", "cargo", "rust"],
210 ),
211 (
212 "javascript",
213 vec![
214 "function", "const", "let", "var", "nodejs", "js", "react", "vue",
215 ],
216 ),
217 ("typescript", vec!["interface", "type", "typescript", "ts"]),
218 (
219 "python",
220 vec!["def", "class", "import", "python", "django", "flask"],
221 ),
222 ("java", vec!["public", "private", "class", "java", "spring"]),
223 ("go", vec!["func", "package", "golang", "go"]),
224 ("c++", vec!["class", "namespace", "cpp", "c++"]),
225 ("c", vec!["struct", "typedef", "c programming"]),
226 ];
227
228 for (lang, keywords) in &language_keywords {
229 if keywords.iter().any(|&keyword| query.contains(keyword)) {
230 return Some(lang.to_string());
231 }
232 }
233
234 None
235 }
236
237 fn detect_component_type(&self, query: &str) -> Option<String> {
239 if query.contains("function") || query.contains("method") || query.contains("fn") {
240 Some("function".to_string())
241 } else if query.contains("class") || query.contains("struct") {
242 Some("class".to_string())
243 } else if query.contains("interface") || query.contains("trait") {
244 Some("interface".to_string())
245 } else if query.contains("variable") || query.contains("constant") {
246 Some("variable".to_string())
247 } else if query.contains("middleware") || query.contains("handler") {
248 Some("middleware".to_string())
249 } else {
250 None
251 }
252 }
253
254 async fn enhance_with_llm(
256 &self,
257 query: &str,
258 intent: &QueryIntent,
259 llm_client: &LlmClient,
260 ) -> Result<(Vec<QueryVariation>, Vec<String>)> {
261 let system_prompt = self.build_enhancement_prompt(intent);
262
263 let user_message = format!(
264 "Original query: \"{}\"\n\nPlease provide:\n1. 2-3 alternative ways to phrase this query for better search results\n2. Important keywords and synonyms\n3. Focus on {} context\n\nRespond in JSON format with 'variations' array and 'keywords' array.",
265 query,
266 match intent {
267 QueryIntent::CodeSearch { .. } => "code search and programming",
268 QueryIntent::Documentation => "documentation and guides",
269 QueryIntent::Configuration => "configuration and settings",
270 QueryIntent::Debugging => "troubleshooting and debugging",
271 QueryIntent::TechnicalConcept => "technical concepts",
272 QueryIntent::Unknown => "general search",
273 }
274 );
275
276 let response = self
278 .call_llm_for_enhancement(llm_client, &system_prompt, &user_message)
279 .await?;
280
281 let parsed_response: serde_json::Value = serde_json::from_str(&response)
282 .unwrap_or_else(|_| json!({"variations": [], "keywords": []}));
283
284 let variations = parsed_response["variations"]
285 .as_array()
286 .unwrap_or(&vec![])
287 .iter()
288 .filter_map(|v| v.as_str())
289 .map(|v| QueryVariation {
290 query: v.to_string(),
291 strategy: SearchStrategy::Semantic,
292 weight: 0.8,
293 })
294 .collect();
295
296 let keywords = parsed_response["keywords"]
297 .as_array()
298 .unwrap_or(&vec![])
299 .iter()
300 .filter_map(|k| k.as_str())
301 .map(|k| k.to_string())
302 .collect();
303
304 Ok((variations, keywords))
305 }
306
307 fn build_enhancement_prompt(&self, intent: &QueryIntent) -> String {
309 match intent {
310 QueryIntent::CodeSearch { language, component_type } => {
311 format!(
312 "You are a code search expert. Help enhance queries for finding {} {} in codebases. Focus on programming patterns, function names, and implementation details.",
313 component_type.as_deref().unwrap_or("code"),
314 language.as_deref().unwrap_or("programming")
315 )
316 },
317 QueryIntent::Documentation => {
318 "You are a documentation search expert. Help enhance queries for finding guides, tutorials, and explanations. Focus on learning objectives and procedural knowledge.".to_string()
319 },
320 QueryIntent::Configuration => {
321 "You are a configuration expert. Help enhance queries for finding settings, environment variables, and configuration patterns.".to_string()
322 },
323 QueryIntent::Debugging => {
324 "You are a debugging expert. Help enhance queries for finding error solutions, troubleshooting guides, and problem resolution.".to_string()
325 },
326 _ => {
327 "You are a technical search expert. Help enhance queries for better search results in technical documentation and code.".to_string()
328 }
329 }
330 }
331
332 async fn call_llm_for_enhancement(
334 &self,
335 _llm_client: &LlmClient,
336 _system_prompt: &str,
337 _user_message: &str,
338 ) -> Result<String> {
339 Ok(json!({
342 "variations": [],
343 "keywords": []
344 })
345 .to_string())
346 }
347
348 fn enhance_with_fallback(&self, query: &str, intent: &QueryIntent) -> Vec<QueryVariation> {
350 let mut variations = Vec::new();
351
352 match intent {
353 QueryIntent::CodeSearch {
354 language,
355 component_type,
356 } => {
357 if let Some(comp_type) = component_type {
359 variations.push(QueryVariation {
360 query: format!("{} {}", comp_type, query),
361 strategy: SearchStrategy::Code,
362 weight: 0.9,
363 });
364 }
365
366 if let Some(lang) = language {
367 variations.push(QueryVariation {
368 query: format!("{} {}", lang, query),
369 strategy: SearchStrategy::Code,
370 weight: 0.8,
371 });
372 }
373
374 if query.contains("where") {
376 let without_where = query.replace("where is", "").replace("where", "");
377 let trimmed = without_where.trim();
378 variations.push(QueryVariation {
379 query: format!("{} implementation", trimmed),
380 strategy: SearchStrategy::Semantic,
381 weight: 0.7,
382 });
383 }
384 }
385 QueryIntent::Documentation => {
386 variations.push(QueryVariation {
388 query: format!("how to {}", query),
389 strategy: SearchStrategy::Semantic,
390 weight: 0.7,
391 });
392 variations.push(QueryVariation {
393 query: format!("{} guide", query),
394 strategy: SearchStrategy::Keyword,
395 weight: 0.6,
396 });
397 }
398 _ => {
399 variations.push(QueryVariation {
401 query: query.to_string(),
402 strategy: SearchStrategy::Keyword,
403 weight: 0.6,
404 });
405 }
406 }
407
408 variations
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_query_intent_detection() {
        let enhancer = QueryEnhancer::new(None, SmartSearchConfig::default());

        // `matches!` only evaluates to a bool; it must be wrapped in `assert!`
        // for the test to check anything.
        let result = enhancer
            .detect_query_intent("where is middleware being used?")
            .await
            .unwrap();
        assert!(
            matches!(result, QueryIntent::CodeSearch { .. }),
            "unexpected intent: {:?}",
            result
        );

        // Code indicators (e.g. "authentication") take priority over configuration
        // keywords, so this query avoids them to exercise the Configuration branch.
        let result = enhancer
            .detect_query_intent("how to change the environment settings")
            .await
            .unwrap();
        assert!(
            matches!(result, QueryIntent::Configuration),
            "unexpected intent: {:?}",
            result
        );

        let result = enhancer
            .detect_query_intent("fix the error in the build")
            .await
            .unwrap();
        assert!(
            matches!(result, QueryIntent::Debugging),
            "unexpected intent: {:?}",
            result
        );

        let result = enhancer
            .detect_query_intent("tutorial for beginners")
            .await
            .unwrap();
        assert!(
            matches!(result, QueryIntent::Documentation),
            "unexpected intent: {:?}",
            result
        );

        let result = enhancer
            .detect_query_intent("what is eventual consistency")
            .await
            .unwrap();
        assert!(
            matches!(result, QueryIntent::TechnicalConcept),
            "unexpected intent: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_fallback_enhancement() {
        let enhancer = QueryEnhancer::new(None, SmartSearchConfig::default());

        let result = enhancer
            .enhance_query("validate_code_security function")
            .await
            .unwrap();
        assert!(result.variations.len() > 1);
        assert_eq!(result.original, "validate_code_security function");
        // The original query is always the first variation.
        assert_eq!(result.variations[0].query, "validate_code_security function");
    }
}