manx_cli/rag/
query_enhancer.rs

1//! Query enhancement system for improved RAG search
2//!
3//! This module provides intelligent query enhancement using LLM when available,
4//! with fallback strategies for semantic understanding without LLM.
5
6use anyhow::Result;
7use serde_json::json;
8use std::sync::Arc;
9
10use crate::rag::{llm::LlmClient, SmartSearchConfig};
11
12/// Enhanced query with variations and metadata
13#[derive(Debug, Clone)]
14pub struct EnhancedQuery {
15    pub original: String,
16    pub variations: Vec<QueryVariation>,
17    pub detected_intent: QueryIntent,
18    #[allow(dead_code)]
19    pub suggested_terms: Vec<String>,
20}
21
22/// Query variation with different search strategies
23#[derive(Debug, Clone)]
24pub struct QueryVariation {
25    pub query: String,
26    pub strategy: SearchStrategy,
27    pub weight: f32,
28}
29
30/// Different search strategies for query variations
31#[derive(Debug, Clone)]
32pub enum SearchStrategy {
33    Semantic, // Use semantic embeddings
34    Keyword,  // Exact keyword matching
35    #[allow(dead_code)]
36    Fuzzy, // Fuzzy string matching
37    Code,     // Code-specific patterns
38    Mixed,    // Combined approach
39}
40
41/// Detected query intent for specialized handling
42#[derive(Debug, Clone)]
43pub enum QueryIntent {
44    CodeSearch {
45        language: Option<String>,
46        component_type: Option<String>, // function, class, variable, etc.
47    },
48    Documentation,
49    Configuration,
50    TechnicalConcept,
51    Debugging,
52    #[allow(dead_code)]
53    Unknown,
54}
55
56/// Query enhancement system
57pub struct QueryEnhancer {
58    llm_client: Option<Arc<LlmClient>>,
59    config: SmartSearchConfig,
60}
61
62impl QueryEnhancer {
63    /// Create a new query enhancer
64    pub fn new(llm_client: Option<Arc<LlmClient>>, config: SmartSearchConfig) -> Self {
65        Self { llm_client, config }
66    }
67
68    /// Enhance a query with multiple variations and strategies
69    pub async fn enhance_query(&self, query: &str) -> Result<EnhancedQuery> {
70        log::debug!("Enhancing query: '{}'", query);
71
72        // Detect query intent first
73        let detected_intent = self.detect_query_intent(query).await?;
74
75        let mut variations = Vec::new();
76        let mut suggested_terms = Vec::new();
77
78        // Try LLM enhancement if available
79        if let Some(ref llm_client) = self.llm_client {
80            if self.config.enable_query_enhancement {
81                match self
82                    .enhance_with_llm(query, &detected_intent, llm_client)
83                    .await
84                {
85                    Ok((llm_variations, llm_terms)) => {
86                        variations.extend(llm_variations);
87                        suggested_terms.extend(llm_terms);
88                        log::debug!(
89                            "LLM enhancement succeeded with {} variations",
90                            variations.len()
91                        );
92                    }
93                    Err(e) => {
94                        log::warn!("LLM enhancement failed, using fallback: {}", e);
95                    }
96                }
97            }
98        }
99
100        // Always include fallback enhancement
101        let fallback_variations = self.enhance_with_fallback(query, &detected_intent);
102        variations.extend(fallback_variations);
103
104        // Add original query as highest priority
105        variations.insert(
106            0,
107            QueryVariation {
108                query: query.to_string(),
109                strategy: SearchStrategy::Mixed,
110                weight: 1.0,
111            },
112        );
113
114        // Limit variations based on config
115        variations.truncate(self.config.max_query_variations.max(1));
116
117        Ok(EnhancedQuery {
118            original: query.to_string(),
119            variations,
120            detected_intent,
121            suggested_terms,
122        })
123    }
124
125    /// Detect the intent of a query for specialized handling
126    async fn detect_query_intent(&self, query: &str) -> Result<QueryIntent> {
127        let query_lower = query.to_lowercase();
128
129        // Check for code-specific patterns
130        if self.is_code_query(&query_lower) {
131            let language = self.detect_programming_language(&query_lower);
132            let component_type = self.detect_component_type(&query_lower);
133
134            return Ok(QueryIntent::CodeSearch {
135                language,
136                component_type,
137            });
138        }
139
140        // Check for configuration queries
141        if query_lower.contains("config")
142            || query_lower.contains("settings")
143            || query_lower.contains("environment")
144        {
145            return Ok(QueryIntent::Configuration);
146        }
147
148        // Check for debugging queries
149        if query_lower.contains("error")
150            || query_lower.contains("bug")
151            || query_lower.contains("debug")
152            || query_lower.contains("issue")
153            || query_lower.contains("problem")
154        {
155            return Ok(QueryIntent::Debugging);
156        }
157
158        // Check for documentation queries
159        if query_lower.contains("how to")
160            || query_lower.contains("guide")
161            || query_lower.contains("tutorial")
162            || query_lower.contains("example")
163        {
164            return Ok(QueryIntent::Documentation);
165        }
166
167        Ok(QueryIntent::TechnicalConcept)
168    }
169
170    /// Check if query is code-related
171    fn is_code_query(&self, query: &str) -> bool {
172        let code_indicators = [
173            "function",
174            "method",
175            "class",
176            "struct",
177            "interface",
178            "variable",
179            "implementation",
180            "where is",
181            "how does",
182            "used",
183            "called",
184            "middleware",
185            "authentication",
186            "validation",
187            "security",
188            "database",
189            "connection",
190            "handler",
191            "controller",
192            "service",
193            "component",
194            "module",
195            "library",
196            "package",
197            "import",
198        ];
199
200        code_indicators
201            .iter()
202            .any(|&indicator| query.contains(indicator))
203    }
204
205    /// Detect programming language from query
206    fn detect_programming_language(&self, query: &str) -> Option<String> {
207        let language_keywords = [
208            (
209                "rust",
210                vec!["fn", "impl", "struct", "trait", "cargo", "rust"],
211            ),
212            (
213                "javascript",
214                vec![
215                    "function", "const", "let", "var", "nodejs", "js", "react", "vue",
216                ],
217            ),
218            ("typescript", vec!["interface", "type", "typescript", "ts"]),
219            (
220                "python",
221                vec!["def", "class", "import", "python", "django", "flask"],
222            ),
223            ("java", vec!["public", "private", "class", "java", "spring"]),
224            ("go", vec!["func", "package", "golang", "go"]),
225            ("c++", vec!["class", "namespace", "cpp", "c++"]),
226            ("c", vec!["struct", "typedef", "c programming"]),
227        ];
228
229        for (lang, keywords) in &language_keywords {
230            if keywords.iter().any(|&keyword| query.contains(keyword)) {
231                return Some(lang.to_string());
232            }
233        }
234
235        None
236    }
237
238    /// Detect component type from query
239    fn detect_component_type(&self, query: &str) -> Option<String> {
240        if query.contains("function") || query.contains("method") || query.contains("fn") {
241            Some("function".to_string())
242        } else if query.contains("class") || query.contains("struct") {
243            Some("class".to_string())
244        } else if query.contains("interface") || query.contains("trait") {
245            Some("interface".to_string())
246        } else if query.contains("variable") || query.contains("constant") {
247            Some("variable".to_string())
248        } else if query.contains("middleware") || query.contains("handler") {
249            Some("middleware".to_string())
250        } else {
251            None
252        }
253    }
254
255    /// Enhance query using LLM
256    async fn enhance_with_llm(
257        &self,
258        query: &str,
259        intent: &QueryIntent,
260        llm_client: &LlmClient,
261    ) -> Result<(Vec<QueryVariation>, Vec<String>)> {
262        let system_prompt = self.build_enhancement_prompt(intent);
263
264        let user_message = format!(
265            "Original query: \"{}\"\n\nPlease provide:\n1. 2-3 alternative ways to phrase this query for better search results\n2. Important keywords and synonyms\n3. Focus on {} context\n\nRespond in JSON format with 'variations' array and 'keywords' array.",
266            query,
267            match intent {
268                QueryIntent::CodeSearch { .. } => "code search and programming",
269                QueryIntent::Documentation => "documentation and guides",
270                QueryIntent::Configuration => "configuration and settings",
271                QueryIntent::Debugging => "troubleshooting and debugging",
272                QueryIntent::TechnicalConcept => "technical concepts",
273                QueryIntent::Unknown => "general search",
274            }
275        );
276
277        // Create a simple LLM request (we'll use a basic approach since we don't have the full LLM implementation details)
278        let response = self
279            .call_llm_for_enhancement(llm_client, &system_prompt, &user_message)
280            .await?;
281
282        let parsed_response: serde_json::Value = serde_json::from_str(&response)
283            .unwrap_or_else(|_| json!({"variations": [], "keywords": []}));
284
285        let variations = parsed_response["variations"]
286            .as_array()
287            .unwrap_or(&vec![])
288            .iter()
289            .filter_map(|v| v.as_str())
290            .map(|v| QueryVariation {
291                query: v.to_string(),
292                strategy: SearchStrategy::Semantic,
293                weight: 0.8,
294            })
295            .collect();
296
297        let keywords = parsed_response["keywords"]
298            .as_array()
299            .unwrap_or(&vec![])
300            .iter()
301            .filter_map(|k| k.as_str())
302            .map(|k| k.to_string())
303            .collect();
304
305        Ok((variations, keywords))
306    }
307
308    /// Build enhancement prompt based on intent
309    fn build_enhancement_prompt(&self, intent: &QueryIntent) -> String {
310        match intent {
311            QueryIntent::CodeSearch { language, component_type } => {
312                format!(
313                    "You are a code search expert. Help enhance queries for finding {} {} in codebases. Focus on programming patterns, function names, and implementation details.",
314                    component_type.as_deref().unwrap_or("code"),
315                    language.as_deref().unwrap_or("programming")
316                )
317            },
318            QueryIntent::Documentation => {
319                "You are a documentation search expert. Help enhance queries for finding guides, tutorials, and explanations. Focus on learning objectives and procedural knowledge.".to_string()
320            },
321            QueryIntent::Configuration => {
322                "You are a configuration expert. Help enhance queries for finding settings, environment variables, and configuration patterns.".to_string()
323            },
324            QueryIntent::Debugging => {
325                "You are a debugging expert. Help enhance queries for finding error solutions, troubleshooting guides, and problem resolution.".to_string()
326            },
327            _ => {
328                "You are a technical search expert. Help enhance queries for better search results in technical documentation and code.".to_string()
329            }
330        }
331    }
332
333    /// Basic LLM call for enhancement (simplified implementation)
334    async fn call_llm_for_enhancement(
335        &self,
336        _llm_client: &LlmClient,
337        _system_prompt: &str,
338        _user_message: &str,
339    ) -> Result<String> {
340        // This is a placeholder - in a real implementation, this would call the LLM
341        // For now, return a basic JSON response to make the system work
342        Ok(json!({
343            "variations": [],
344            "keywords": []
345        })
346        .to_string())
347    }
348
349    /// Fallback enhancement without LLM
350    fn enhance_with_fallback(&self, query: &str, intent: &QueryIntent) -> Vec<QueryVariation> {
351        let mut variations = Vec::new();
352
353        match intent {
354            QueryIntent::CodeSearch {
355                language,
356                component_type,
357            } => {
358                // Add code-specific variations
359                if let Some(comp_type) = component_type {
360                    variations.push(QueryVariation {
361                        query: format!("{} {}", comp_type, query),
362                        strategy: SearchStrategy::Code,
363                        weight: 0.9,
364                    });
365                }
366
367                if let Some(lang) = language {
368                    variations.push(QueryVariation {
369                        query: format!("{} {}", lang, query),
370                        strategy: SearchStrategy::Code,
371                        weight: 0.8,
372                    });
373                }
374
375                // Add common code search patterns
376                if query.contains("where") {
377                    let without_where = query.replace("where is", "").replace("where", "");
378                    let trimmed = without_where.trim();
379                    variations.push(QueryVariation {
380                        query: format!("{} implementation", trimmed),
381                        strategy: SearchStrategy::Semantic,
382                        weight: 0.7,
383                    });
384                }
385            }
386            QueryIntent::Documentation => {
387                // Add documentation-focused variations
388                variations.push(QueryVariation {
389                    query: format!("how to {}", query),
390                    strategy: SearchStrategy::Semantic,
391                    weight: 0.7,
392                });
393                variations.push(QueryVariation {
394                    query: format!("{} guide", query),
395                    strategy: SearchStrategy::Keyword,
396                    weight: 0.6,
397                });
398            }
399            _ => {
400                // Generic fallback variations
401                variations.push(QueryVariation {
402                    query: query.to_string(),
403                    strategy: SearchStrategy::Keyword,
404                    weight: 0.6,
405                });
406            }
407        }
408
409        variations
410    }
411}
412
413#[cfg(test)]
414mod tests {
415    use super::*;
416
417    #[tokio::test]
418    async fn test_query_intent_detection() {
419        let enhancer = QueryEnhancer::new(None, SmartSearchConfig::default());
420
421        let result = enhancer
422            .detect_query_intent("where is middleware being used?")
423            .await
424            .unwrap();
425        matches!(result, QueryIntent::CodeSearch { .. });
426
427        let result = enhancer
428            .detect_query_intent("how to configure authentication")
429            .await
430            .unwrap();
431        matches!(result, QueryIntent::Configuration);
432    }
433
434    #[tokio::test]
435    async fn test_fallback_enhancement() {
436        let enhancer = QueryEnhancer::new(None, SmartSearchConfig::default());
437
438        let result = enhancer
439            .enhance_query("validate_code_security function")
440            .await
441            .unwrap();
442        assert!(result.variations.len() > 1);
443        assert_eq!(result.original, "validate_code_security function");
444    }
445}