1use anyhow::Result;
7use serde_json::json;
8use std::sync::Arc;
9
10use crate::rag::{llm::LlmClient, SmartSearchConfig};
11
12#[derive(Debug, Clone)]
14pub struct EnhancedQuery {
15 pub original: String,
16 pub variations: Vec<QueryVariation>,
17 pub detected_intent: QueryIntent,
18 #[allow(dead_code)]
19 pub suggested_terms: Vec<String>,
20}
21
22#[derive(Debug, Clone)]
24pub struct QueryVariation {
25 pub query: String,
26 pub strategy: SearchStrategy,
27 pub weight: f32,
28}
29
/// Search strategy attached to a [`QueryVariation`].
///
/// `PartialEq`/`Eq` are derived so strategies can be compared with `==` and
/// used in `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SearchStrategy {
    /// Used in this file for LLM rephrasings and reworded fallback queries.
    Semantic,
    /// Used in this file for plain keyword fallbacks.
    Keyword,
    /// Not constructed anywhere in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    Fuzzy,
    /// Used in this file for language- or component-prefixed code queries.
    Code,
    /// Used in this file for the unmodified original query.
    Mixed,
}
40
/// What the user appears to be looking for, as inferred from the query text.
///
/// `PartialEq`/`Eq` are derived so intents can be compared with `==` and used
/// in `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryIntent {
    /// The query looks like a search for code.
    CodeSearch {
        /// Programming language detected in the query, if any.
        language: Option<String>,
        /// Kind of code element detected in the query (e.g. `"function"`), if any.
        component_type: Option<String>,
    },
    /// The query asks for guides, tutorials or examples.
    Documentation,
    /// The query is about configuration, settings or environment.
    Configuration,
    /// Default intent when no other heuristic matches.
    TechnicalConcept,
    /// The query mentions errors, bugs or problems.
    Debugging,
    /// Not produced by the heuristics in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    Unknown,
}
55
56pub struct QueryEnhancer {
58 llm_client: Option<Arc<LlmClient>>,
59 config: SmartSearchConfig,
60}
61
62impl QueryEnhancer {
63 pub fn new(llm_client: Option<Arc<LlmClient>>, config: SmartSearchConfig) -> Self {
65 Self { llm_client, config }
66 }
67
68 pub async fn enhance_query(&self, query: &str) -> Result<EnhancedQuery> {
70 log::debug!("Enhancing query: '{}'", query);
71
72 let detected_intent = self.detect_query_intent(query).await?;
74
75 let mut variations = Vec::new();
76 let mut suggested_terms = Vec::new();
77
78 if let Some(ref llm_client) = self.llm_client {
80 if self.config.enable_query_enhancement {
81 match self
82 .enhance_with_llm(query, &detected_intent, llm_client)
83 .await
84 {
85 Ok((llm_variations, llm_terms)) => {
86 variations.extend(llm_variations);
87 suggested_terms.extend(llm_terms);
88 log::debug!(
89 "LLM enhancement succeeded with {} variations",
90 variations.len()
91 );
92 }
93 Err(e) => {
94 log::warn!("LLM enhancement failed, using fallback: {}", e);
95 }
96 }
97 }
98 }
99
100 let fallback_variations = self.enhance_with_fallback(query, &detected_intent);
102 variations.extend(fallback_variations);
103
104 variations.insert(
106 0,
107 QueryVariation {
108 query: query.to_string(),
109 strategy: SearchStrategy::Mixed,
110 weight: 1.0,
111 },
112 );
113
114 variations.truncate(self.config.max_query_variations.max(1));
116
117 Ok(EnhancedQuery {
118 original: query.to_string(),
119 variations,
120 detected_intent,
121 suggested_terms,
122 })
123 }
124
125 async fn detect_query_intent(&self, query: &str) -> Result<QueryIntent> {
127 let query_lower = query.to_lowercase();
128
129 if self.is_code_query(&query_lower) {
131 let language = self.detect_programming_language(&query_lower);
132 let component_type = self.detect_component_type(&query_lower);
133
134 return Ok(QueryIntent::CodeSearch {
135 language,
136 component_type,
137 });
138 }
139
140 if query_lower.contains("config")
142 || query_lower.contains("settings")
143 || query_lower.contains("environment")
144 {
145 return Ok(QueryIntent::Configuration);
146 }
147
148 if query_lower.contains("error")
150 || query_lower.contains("bug")
151 || query_lower.contains("debug")
152 || query_lower.contains("issue")
153 || query_lower.contains("problem")
154 {
155 return Ok(QueryIntent::Debugging);
156 }
157
158 if query_lower.contains("how to")
160 || query_lower.contains("guide")
161 || query_lower.contains("tutorial")
162 || query_lower.contains("example")
163 {
164 return Ok(QueryIntent::Documentation);
165 }
166
167 Ok(QueryIntent::TechnicalConcept)
168 }
169
170 fn is_code_query(&self, query: &str) -> bool {
172 let code_indicators = [
173 "function",
174 "method",
175 "class",
176 "struct",
177 "interface",
178 "variable",
179 "implementation",
180 "where is",
181 "how does",
182 "used",
183 "called",
184 "middleware",
185 "authentication",
186 "validation",
187 "security",
188 "database",
189 "connection",
190 "handler",
191 "controller",
192 "service",
193 "component",
194 "module",
195 "library",
196 "package",
197 "import",
198 ];
199
200 code_indicators
201 .iter()
202 .any(|&indicator| query.contains(indicator))
203 }
204
205 fn detect_programming_language(&self, query: &str) -> Option<String> {
207 let language_keywords = [
208 (
209 "rust",
210 vec!["fn", "impl", "struct", "trait", "cargo", "rust"],
211 ),
212 (
213 "javascript",
214 vec![
215 "function", "const", "let", "var", "nodejs", "js", "react", "vue",
216 ],
217 ),
218 ("typescript", vec!["interface", "type", "typescript", "ts"]),
219 (
220 "python",
221 vec!["def", "class", "import", "python", "django", "flask"],
222 ),
223 ("java", vec!["public", "private", "class", "java", "spring"]),
224 ("go", vec!["func", "package", "golang", "go"]),
225 ("c++", vec!["class", "namespace", "cpp", "c++"]),
226 ("c", vec!["struct", "typedef", "c programming"]),
227 ];
228
229 for (lang, keywords) in &language_keywords {
230 if keywords.iter().any(|&keyword| query.contains(keyword)) {
231 return Some(lang.to_string());
232 }
233 }
234
235 None
236 }
237
238 fn detect_component_type(&self, query: &str) -> Option<String> {
240 if query.contains("function") || query.contains("method") || query.contains("fn") {
241 Some("function".to_string())
242 } else if query.contains("class") || query.contains("struct") {
243 Some("class".to_string())
244 } else if query.contains("interface") || query.contains("trait") {
245 Some("interface".to_string())
246 } else if query.contains("variable") || query.contains("constant") {
247 Some("variable".to_string())
248 } else if query.contains("middleware") || query.contains("handler") {
249 Some("middleware".to_string())
250 } else {
251 None
252 }
253 }
254
255 async fn enhance_with_llm(
257 &self,
258 query: &str,
259 intent: &QueryIntent,
260 llm_client: &LlmClient,
261 ) -> Result<(Vec<QueryVariation>, Vec<String>)> {
262 let system_prompt = self.build_enhancement_prompt(intent);
263
264 let user_message = format!(
265 "Original query: \"{}\"\n\nPlease provide:\n1. 2-3 alternative ways to phrase this query for better search results\n2. Important keywords and synonyms\n3. Focus on {} context\n\nRespond in JSON format with 'variations' array and 'keywords' array.",
266 query,
267 match intent {
268 QueryIntent::CodeSearch { .. } => "code search and programming",
269 QueryIntent::Documentation => "documentation and guides",
270 QueryIntent::Configuration => "configuration and settings",
271 QueryIntent::Debugging => "troubleshooting and debugging",
272 QueryIntent::TechnicalConcept => "technical concepts",
273 QueryIntent::Unknown => "general search",
274 }
275 );
276
277 let response = self
279 .call_llm_for_enhancement(llm_client, &system_prompt, &user_message)
280 .await?;
281
282 let parsed_response: serde_json::Value = serde_json::from_str(&response)
283 .unwrap_or_else(|_| json!({"variations": [], "keywords": []}));
284
285 let variations = parsed_response["variations"]
286 .as_array()
287 .unwrap_or(&vec![])
288 .iter()
289 .filter_map(|v| v.as_str())
290 .map(|v| QueryVariation {
291 query: v.to_string(),
292 strategy: SearchStrategy::Semantic,
293 weight: 0.8,
294 })
295 .collect();
296
297 let keywords = parsed_response["keywords"]
298 .as_array()
299 .unwrap_or(&vec![])
300 .iter()
301 .filter_map(|k| k.as_str())
302 .map(|k| k.to_string())
303 .collect();
304
305 Ok((variations, keywords))
306 }
307
308 fn build_enhancement_prompt(&self, intent: &QueryIntent) -> String {
310 match intent {
311 QueryIntent::CodeSearch { language, component_type } => {
312 format!(
313 "You are a code search expert. Help enhance queries for finding {} {} in codebases. Focus on programming patterns, function names, and implementation details.",
314 component_type.as_deref().unwrap_or("code"),
315 language.as_deref().unwrap_or("programming")
316 )
317 },
318 QueryIntent::Documentation => {
319 "You are a documentation search expert. Help enhance queries for finding guides, tutorials, and explanations. Focus on learning objectives and procedural knowledge.".to_string()
320 },
321 QueryIntent::Configuration => {
322 "You are a configuration expert. Help enhance queries for finding settings, environment variables, and configuration patterns.".to_string()
323 },
324 QueryIntent::Debugging => {
325 "You are a debugging expert. Help enhance queries for finding error solutions, troubleshooting guides, and problem resolution.".to_string()
326 },
327 _ => {
328 "You are a technical search expert. Help enhance queries for better search results in technical documentation and code.".to_string()
329 }
330 }
331 }
332
333 async fn call_llm_for_enhancement(
335 &self,
336 _llm_client: &LlmClient,
337 _system_prompt: &str,
338 _user_message: &str,
339 ) -> Result<String> {
340 Ok(json!({
343 "variations": [],
344 "keywords": []
345 })
346 .to_string())
347 }
348
349 fn enhance_with_fallback(&self, query: &str, intent: &QueryIntent) -> Vec<QueryVariation> {
351 let mut variations = Vec::new();
352
353 match intent {
354 QueryIntent::CodeSearch {
355 language,
356 component_type,
357 } => {
358 if let Some(comp_type) = component_type {
360 variations.push(QueryVariation {
361 query: format!("{} {}", comp_type, query),
362 strategy: SearchStrategy::Code,
363 weight: 0.9,
364 });
365 }
366
367 if let Some(lang) = language {
368 variations.push(QueryVariation {
369 query: format!("{} {}", lang, query),
370 strategy: SearchStrategy::Code,
371 weight: 0.8,
372 });
373 }
374
375 if query.contains("where") {
377 let without_where = query.replace("where is", "").replace("where", "");
378 let trimmed = without_where.trim();
379 variations.push(QueryVariation {
380 query: format!("{} implementation", trimmed),
381 strategy: SearchStrategy::Semantic,
382 weight: 0.7,
383 });
384 }
385 }
386 QueryIntent::Documentation => {
387 variations.push(QueryVariation {
389 query: format!("how to {}", query),
390 strategy: SearchStrategy::Semantic,
391 weight: 0.7,
392 });
393 variations.push(QueryVariation {
394 query: format!("{} guide", query),
395 strategy: SearchStrategy::Keyword,
396 weight: 0.6,
397 });
398 }
399 _ => {
400 variations.push(QueryVariation {
402 query: query.to_string(),
403 strategy: SearchStrategy::Keyword,
404 weight: 0.6,
405 });
406 }
407 }
408
409 variations
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Intent detection must classify code and configuration queries.
    #[tokio::test]
    async fn test_query_intent_detection() {
        let enhancer = QueryEnhancer::new(None, SmartSearchConfig::default());

        // A bare `matches!(..)` statement discards its result; it has to be
        // wrapped in `assert!` to check anything.
        let result = enhancer
            .detect_query_intent("where is middleware being used?")
            .await
            .unwrap();
        assert!(
            matches!(result, QueryIntent::CodeSearch { .. }),
            "expected CodeSearch, got {:?}",
            result
        );

        // The query avoids every code indicator. "how to configure
        // authentication" would be classified as CodeSearch, because
        // "authentication" is a code indicator and that check runs first.
        let result = enhancer
            .detect_query_intent("how to configure logging")
            .await
            .unwrap();
        assert!(
            matches!(result, QueryIntent::Configuration),
            "expected Configuration, got {:?}",
            result
        );
    }

    /// Without an LLM client the heuristics must still add variations.
    #[tokio::test]
    async fn test_fallback_enhancement() {
        let enhancer = QueryEnhancer::new(None, SmartSearchConfig::default());

        let result = enhancer
            .enhance_query("validate_code_security function")
            .await
            .unwrap();
        assert!(result.variations.len() > 1);
        assert_eq!(result.original, "validate_code_security function");
        // The original query is always the first variation.
        assert_eq!(
            result.variations[0].query,
            "validate_code_security function"
        );
    }
}