manx_cli/rag/
query_enhancer.rs

1//! Query enhancement system for improved RAG search
2//!
3//! This module provides intelligent query enhancement using LLM when available,
4//! with fallback strategies for semantic understanding without LLM.
5
6use anyhow::Result;
7use serde_json::json;
8
9use crate::rag::{llm::LlmClient, SmartSearchConfig};
10
/// Enhanced query with variations and metadata
///
/// This is the value returned by `QueryEnhancer::enhance_query`.
#[derive(Debug, Clone)]
pub struct EnhancedQuery {
    /// The query text exactly as it was passed to `enhance_query`.
    pub original: String,
    /// Search variations to try. `enhance_query` places the original query
    /// first and truncates the list to the configured maximum (at least one).
    pub variations: Vec<QueryVariation>,
    /// Intent detected from the query text.
    pub detected_intent: QueryIntent,
    /// Keywords collected from the LLM enhancement step. Empty when that step
    /// is skipped, fails, or yields no keywords.
    // Written by `enhance_query` but not read anywhere in the visible module,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub suggested_terms: Vec<String>,
}
20
/// Query variation with different search strategies
#[derive(Debug, Clone)]
pub struct QueryVariation {
    /// Query text to run for this variation.
    pub query: String,
    /// Search strategy tagged on this variation.
    pub strategy: SearchStrategy,
    /// Relative priority of this variation. This module assigns 1.0 to the
    /// original query and values between 0.6 and 0.9 to generated ones.
    /// NOTE(review): how consumers apply the weight is not visible here.
    pub weight: f32,
}
28
/// Different search strategies for query variations
///
/// Derives `PartialEq`/`Eq` so strategies can be compared directly, for
/// example with `assert_eq!` in tests or `==` when filtering variations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SearchStrategy {
    /// Use semantic embeddings
    Semantic,
    /// Exact keyword matching
    Keyword,
    /// Fuzzy string matching
    // Not constructed anywhere in the visible module, hence the allowance.
    #[allow(dead_code)]
    Fuzzy,
    /// Code-specific patterns
    Code,
    /// Combined approach
    Mixed,
}
39
/// Detected query intent for specialized handling
///
/// Derives `PartialEq`/`Eq` so intents can be compared directly, for example
/// with `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryIntent {
    /// The query asks about code.
    CodeSearch {
        /// Programming language guessed from the query, if any.
        language: Option<String>,
        /// Kind of code element asked about (function, class, variable, etc.), if any.
        component_type: Option<String>,
    },
    /// The query asks for guides, tutorials, or examples.
    Documentation,
    /// The query asks about configuration, settings, or environment.
    Configuration,
    /// A general technical question.
    TechnicalConcept,
    /// The query asks about errors, bugs, or other problems.
    Debugging,
    /// No intent could be determined.
    // Not constructed anywhere in the visible module, hence the allowance.
    #[allow(dead_code)]
    Unknown,
}
54
/// Query enhancement system
///
/// Combines an optional LLM-backed enhancement step with heuristic fallback
/// variations derived from the detected query intent.
pub struct QueryEnhancer {
    /// Optional LLM client. When `None`, only the heuristic fallback
    /// variations are generated.
    llm_client: Option<LlmClient>,
    /// Search settings. This module reads `enable_query_enhancement` and
    /// `max_query_variations` from it.
    config: SmartSearchConfig,
}
60
61impl QueryEnhancer {
62    /// Create a new query enhancer
63    pub fn new(llm_client: Option<LlmClient>, config: SmartSearchConfig) -> Self {
64        Self { llm_client, config }
65    }
66
67    /// Enhance a query with multiple variations and strategies
68    pub async fn enhance_query(&self, query: &str) -> Result<EnhancedQuery> {
69        log::debug!("Enhancing query: '{}'", query);
70
71        // Detect query intent first
72        let detected_intent = self.detect_query_intent(query).await?;
73
74        let mut variations = Vec::new();
75        let mut suggested_terms = Vec::new();
76
77        // Try LLM enhancement if available
78        if let Some(ref llm_client) = self.llm_client {
79            if self.config.enable_query_enhancement {
80                match self
81                    .enhance_with_llm(query, &detected_intent, llm_client)
82                    .await
83                {
84                    Ok((llm_variations, llm_terms)) => {
85                        variations.extend(llm_variations);
86                        suggested_terms.extend(llm_terms);
87                        log::debug!(
88                            "LLM enhancement succeeded with {} variations",
89                            variations.len()
90                        );
91                    }
92                    Err(e) => {
93                        log::warn!("LLM enhancement failed, using fallback: {}", e);
94                    }
95                }
96            }
97        }
98
99        // Always include fallback enhancement
100        let fallback_variations = self.enhance_with_fallback(query, &detected_intent);
101        variations.extend(fallback_variations);
102
103        // Add original query as highest priority
104        variations.insert(
105            0,
106            QueryVariation {
107                query: query.to_string(),
108                strategy: SearchStrategy::Mixed,
109                weight: 1.0,
110            },
111        );
112
113        // Limit variations based on config
114        variations.truncate(self.config.max_query_variations.max(1));
115
116        Ok(EnhancedQuery {
117            original: query.to_string(),
118            variations,
119            detected_intent,
120            suggested_terms,
121        })
122    }
123
124    /// Detect the intent of a query for specialized handling
125    async fn detect_query_intent(&self, query: &str) -> Result<QueryIntent> {
126        let query_lower = query.to_lowercase();
127
128        // Check for code-specific patterns
129        if self.is_code_query(&query_lower) {
130            let language = self.detect_programming_language(&query_lower);
131            let component_type = self.detect_component_type(&query_lower);
132
133            return Ok(QueryIntent::CodeSearch {
134                language,
135                component_type,
136            });
137        }
138
139        // Check for configuration queries
140        if query_lower.contains("config")
141            || query_lower.contains("settings")
142            || query_lower.contains("environment")
143        {
144            return Ok(QueryIntent::Configuration);
145        }
146
147        // Check for debugging queries
148        if query_lower.contains("error")
149            || query_lower.contains("bug")
150            || query_lower.contains("debug")
151            || query_lower.contains("issue")
152            || query_lower.contains("problem")
153        {
154            return Ok(QueryIntent::Debugging);
155        }
156
157        // Check for documentation queries
158        if query_lower.contains("how to")
159            || query_lower.contains("guide")
160            || query_lower.contains("tutorial")
161            || query_lower.contains("example")
162        {
163            return Ok(QueryIntent::Documentation);
164        }
165
166        Ok(QueryIntent::TechnicalConcept)
167    }
168
169    /// Check if query is code-related
170    fn is_code_query(&self, query: &str) -> bool {
171        let code_indicators = [
172            "function",
173            "method",
174            "class",
175            "struct",
176            "interface",
177            "variable",
178            "implementation",
179            "where is",
180            "how does",
181            "used",
182            "called",
183            "middleware",
184            "authentication",
185            "validation",
186            "security",
187            "database",
188            "connection",
189            "handler",
190            "controller",
191            "service",
192            "component",
193            "module",
194            "library",
195            "package",
196            "import",
197        ];
198
199        code_indicators
200            .iter()
201            .any(|&indicator| query.contains(indicator))
202    }
203
204    /// Detect programming language from query
205    fn detect_programming_language(&self, query: &str) -> Option<String> {
206        let language_keywords = [
207            (
208                "rust",
209                vec!["fn", "impl", "struct", "trait", "cargo", "rust"],
210            ),
211            (
212                "javascript",
213                vec![
214                    "function", "const", "let", "var", "nodejs", "js", "react", "vue",
215                ],
216            ),
217            ("typescript", vec!["interface", "type", "typescript", "ts"]),
218            (
219                "python",
220                vec!["def", "class", "import", "python", "django", "flask"],
221            ),
222            ("java", vec!["public", "private", "class", "java", "spring"]),
223            ("go", vec!["func", "package", "golang", "go"]),
224            ("c++", vec!["class", "namespace", "cpp", "c++"]),
225            ("c", vec!["struct", "typedef", "c programming"]),
226        ];
227
228        for (lang, keywords) in &language_keywords {
229            if keywords.iter().any(|&keyword| query.contains(keyword)) {
230                return Some(lang.to_string());
231            }
232        }
233
234        None
235    }
236
237    /// Detect component type from query
238    fn detect_component_type(&self, query: &str) -> Option<String> {
239        if query.contains("function") || query.contains("method") || query.contains("fn") {
240            Some("function".to_string())
241        } else if query.contains("class") || query.contains("struct") {
242            Some("class".to_string())
243        } else if query.contains("interface") || query.contains("trait") {
244            Some("interface".to_string())
245        } else if query.contains("variable") || query.contains("constant") {
246            Some("variable".to_string())
247        } else if query.contains("middleware") || query.contains("handler") {
248            Some("middleware".to_string())
249        } else {
250            None
251        }
252    }
253
254    /// Enhance query using LLM
255    async fn enhance_with_llm(
256        &self,
257        query: &str,
258        intent: &QueryIntent,
259        llm_client: &LlmClient,
260    ) -> Result<(Vec<QueryVariation>, Vec<String>)> {
261        let system_prompt = self.build_enhancement_prompt(intent);
262
263        let user_message = format!(
264            "Original query: \"{}\"\n\nPlease provide:\n1. 2-3 alternative ways to phrase this query for better search results\n2. Important keywords and synonyms\n3. Focus on {} context\n\nRespond in JSON format with 'variations' array and 'keywords' array.",
265            query,
266            match intent {
267                QueryIntent::CodeSearch { .. } => "code search and programming",
268                QueryIntent::Documentation => "documentation and guides",
269                QueryIntent::Configuration => "configuration and settings",
270                QueryIntent::Debugging => "troubleshooting and debugging",
271                QueryIntent::TechnicalConcept => "technical concepts",
272                QueryIntent::Unknown => "general search",
273            }
274        );
275
276        // Create a simple LLM request (we'll use a basic approach since we don't have the full LLM implementation details)
277        let response = self
278            .call_llm_for_enhancement(llm_client, &system_prompt, &user_message)
279            .await?;
280
281        let parsed_response: serde_json::Value = serde_json::from_str(&response)
282            .unwrap_or_else(|_| json!({"variations": [], "keywords": []}));
283
284        let variations = parsed_response["variations"]
285            .as_array()
286            .unwrap_or(&vec![])
287            .iter()
288            .filter_map(|v| v.as_str())
289            .map(|v| QueryVariation {
290                query: v.to_string(),
291                strategy: SearchStrategy::Semantic,
292                weight: 0.8,
293            })
294            .collect();
295
296        let keywords = parsed_response["keywords"]
297            .as_array()
298            .unwrap_or(&vec![])
299            .iter()
300            .filter_map(|k| k.as_str())
301            .map(|k| k.to_string())
302            .collect();
303
304        Ok((variations, keywords))
305    }
306
307    /// Build enhancement prompt based on intent
308    fn build_enhancement_prompt(&self, intent: &QueryIntent) -> String {
309        match intent {
310            QueryIntent::CodeSearch { language, component_type } => {
311                format!(
312                    "You are a code search expert. Help enhance queries for finding {} {} in codebases. Focus on programming patterns, function names, and implementation details.",
313                    component_type.as_deref().unwrap_or("code"),
314                    language.as_deref().unwrap_or("programming")
315                )
316            },
317            QueryIntent::Documentation => {
318                "You are a documentation search expert. Help enhance queries for finding guides, tutorials, and explanations. Focus on learning objectives and procedural knowledge.".to_string()
319            },
320            QueryIntent::Configuration => {
321                "You are a configuration expert. Help enhance queries for finding settings, environment variables, and configuration patterns.".to_string()
322            },
323            QueryIntent::Debugging => {
324                "You are a debugging expert. Help enhance queries for finding error solutions, troubleshooting guides, and problem resolution.".to_string()
325            },
326            _ => {
327                "You are a technical search expert. Help enhance queries for better search results in technical documentation and code.".to_string()
328            }
329        }
330    }
331
332    /// Basic LLM call for enhancement (simplified implementation)
333    async fn call_llm_for_enhancement(
334        &self,
335        _llm_client: &LlmClient,
336        _system_prompt: &str,
337        _user_message: &str,
338    ) -> Result<String> {
339        // This is a placeholder - in a real implementation, this would call the LLM
340        // For now, return a basic JSON response to make the system work
341        Ok(json!({
342            "variations": [],
343            "keywords": []
344        })
345        .to_string())
346    }
347
348    /// Fallback enhancement without LLM
349    fn enhance_with_fallback(&self, query: &str, intent: &QueryIntent) -> Vec<QueryVariation> {
350        let mut variations = Vec::new();
351
352        match intent {
353            QueryIntent::CodeSearch {
354                language,
355                component_type,
356            } => {
357                // Add code-specific variations
358                if let Some(comp_type) = component_type {
359                    variations.push(QueryVariation {
360                        query: format!("{} {}", comp_type, query),
361                        strategy: SearchStrategy::Code,
362                        weight: 0.9,
363                    });
364                }
365
366                if let Some(lang) = language {
367                    variations.push(QueryVariation {
368                        query: format!("{} {}", lang, query),
369                        strategy: SearchStrategy::Code,
370                        weight: 0.8,
371                    });
372                }
373
374                // Add common code search patterns
375                if query.contains("where") {
376                    let without_where = query.replace("where is", "").replace("where", "");
377                    let trimmed = without_where.trim();
378                    variations.push(QueryVariation {
379                        query: format!("{} implementation", trimmed),
380                        strategy: SearchStrategy::Semantic,
381                        weight: 0.7,
382                    });
383                }
384            }
385            QueryIntent::Documentation => {
386                // Add documentation-focused variations
387                variations.push(QueryVariation {
388                    query: format!("how to {}", query),
389                    strategy: SearchStrategy::Semantic,
390                    weight: 0.7,
391                });
392                variations.push(QueryVariation {
393                    query: format!("{} guide", query),
394                    strategy: SearchStrategy::Keyword,
395                    weight: 0.6,
396                });
397            }
398            _ => {
399                // Generic fallback variations
400                variations.push(QueryVariation {
401                    query: query.to_string(),
402                    strategy: SearchStrategy::Keyword,
403                    weight: 0.6,
404                });
405            }
406        }
407
408        variations
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Intent detection recognizes code and configuration queries.
    #[tokio::test]
    async fn test_query_intent_detection() {
        let enhancer = QueryEnhancer::new(None, SmartSearchConfig::default());

        let result = enhancer
            .detect_query_intent("where is middleware being used?")
            .await
            .unwrap();
        // `matches!` only yields a bool; it has to be asserted to test anything.
        assert!(
            matches!(result, QueryIntent::CodeSearch { .. }),
            "expected CodeSearch, got {:?}",
            result
        );

        // The query must avoid code indicators such as "authentication":
        // the code check runs first and would classify it as CodeSearch.
        let result = enhancer
            .detect_query_intent("how to configure logging")
            .await
            .unwrap();
        assert!(
            matches!(result, QueryIntent::Configuration),
            "expected Configuration, got {:?}",
            result
        );
    }

    /// Without an LLM client the fallback still adds variations after the original query.
    #[tokio::test]
    async fn test_fallback_enhancement() {
        let enhancer = QueryEnhancer::new(None, SmartSearchConfig::default());

        let result = enhancer
            .enhance_query("validate_code_security function")
            .await
            .unwrap();
        assert!(result.variations.len() > 1);
        assert_eq!(result.original, "validate_code_security function");
        assert_eq!(
            result.variations[0].query,
            "validate_code_security function"
        );
    }
}