lethe_core_rust/
query_understanding.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use crate::error::Result;
4use regex::Regex;
5use std::sync::OnceLock;
6
7/// Pre-compiled regex patterns for query analysis
8struct QueryRegexes {
9    code_function_call: Regex,
10    code_method_access: Regex,
11    code_punctuation: Regex,
12    code_keywords: Regex,
13    complexity_complex: Regex,
14    complexity_simple: Regex,
15    year_pattern: Regex,
16    date_pattern: Regex,
17    month_pattern: Regex,
18}
19
20impl QueryRegexes {
21    fn new() -> Self {
22        Self {
23            code_function_call: Regex::new(r"\w+\(\)").unwrap(),
24            code_method_access: Regex::new(r"\w+\.\w+").unwrap(),
25            code_punctuation: Regex::new(r"[{}:;\[\]]").unwrap(),
26            code_keywords: Regex::new(r"(?i)\b(def|class|import|function|const|let|var)\b").unwrap(),
27            complexity_complex: Regex::new(r"(?i)\b(complex|advanced|sophisticated|intricate)\b").unwrap(),
28            complexity_simple: Regex::new(r"(?i)\b(simple|basic|easy|straightforward)\b").unwrap(),
29            year_pattern: Regex::new(r"\b\d{4}\b").unwrap(),
30            date_pattern: Regex::new(r"\b\d{1,2}/\d{1,2}/\d{4}\b").unwrap(),
31            month_pattern: Regex::new(r"(?i)\b(january|february|march|april|may|june|july|august|september|october|november|december)\b").unwrap(),
32        }
33    }
34}
35
36static QUERY_REGEXES: OnceLock<QueryRegexes> = OnceLock::new();
37
38fn get_query_regexes() -> &'static QueryRegexes {
39    QUERY_REGEXES.get_or_init(QueryRegexes::new)
40}
41
42/// Static classification patterns to replace hardcoded logic
43static QUERY_TYPE_PATTERNS: &[(QueryType, &[&str])] = &[
44    (QueryType::Definitional, &["what is", "define", "definition of", "meaning of"]),
45    (QueryType::Procedural, &["how to", "steps to", "process of", "method to"]),
46    (QueryType::Comparative, &["compare", "difference between", "vs", "versus", "better than"]),
47    (QueryType::Enumerative, &["list of", "examples of", "types of", "kinds of"]),
48    (QueryType::Analytical, &["why", "analyze", "explain", "reason"]),
49    (QueryType::Subjective, &["opinion", "think", "feel", "recommend", "suggest"]),
50];
51
52static QUERY_INTENT_PATTERNS: &[(QueryIntent, &[&str])] = &[
53    (QueryIntent::Debug, &["error", "debug", "fix", "problem", "issue", "bug"]),
54    (QueryIntent::Code, &["code", "implement", "function", "class", "method"]),
55    (QueryIntent::Compare, &["compare", "difference", "vs", "versus"]),
56    (QueryIntent::Guide, &["steps", "guide", "tutorial", "instructions"]),
57    (QueryIntent::Explain, &["explain", "understand", "what", "clarify"]),
58    (QueryIntent::Assist, &["help", "assist", "how to", "need"]),
59    (QueryIntent::Chat, &["hello", "hi", "thanks", "thank you"]),
60];
61
/// Technical domains and the lowercase terms that indicate each one.
///
/// Each entry is `(domain name, terms)`. A term may be listed under more than
/// one domain (for example `"javascript"` and `"api"`).
static TECHNICAL_DOMAINS: &[(&str, &[&str])] = &[
    (
        "programming",
        &[
            "code",
            "function",
            "variable",
            "algorithm",
            "programming",
            "software",
            "debug",
            "api",
            "library",
            "javascript",
            "python",
            "java",
            "rust",
            "typescript",
        ],
    ),
    (
        "machine_learning",
        &[
            "machine learning",
            "neural network",
            "model",
            "training",
            "dataset",
            "prediction",
            "classification",
            "ai",
            "artificial intelligence",
        ],
    ),
    (
        "web_development",
        &[
            "html",
            "css",
            "javascript",
            "react",
            "vue",
            "angular",
            "frontend",
            "backend",
            "web",
            "http",
            "api",
            "rest",
        ],
    ),
    (
        "database",
        &[
            "database",
            "sql",
            "query",
            "table",
            "index",
            "schema",
            "postgres",
            "mysql",
            "mongodb",
            "nosql",
        ],
    ),
];
80
/// Lowercase interrogatives and auxiliary verbs treated as question words.
static QUESTION_WORDS: &[&str] = &[
    // interrogatives
    "what",
    "how",
    "why",
    "when",
    "where",
    "who",
    "which",
    "whose",
    // modal and auxiliary verbs
    "can",
    "could",
    "should",
    "would",
    "will",
    "do",
    "does",
    "did",
    // forms of "be" and "have"
    "is",
    "are",
    "was",
    "were",
    "have",
    "has",
    "had",
];
86
87/// Query classification types
88#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
89pub enum QueryType {
90    /// Simple factual question
91    Factual,
92    /// Complex analytical question requiring reasoning
93    Analytical,
94    /// Question asking for a comparison
95    Comparative,
96    /// Question asking for a list or enumeration
97    Enumerative,
98    /// Question asking for a definition
99    Definitional,
100    /// Question asking for procedural steps
101    Procedural,
102    /// Question asking for code or technical implementation
103    Technical,
104    /// Question asking for opinion or subjective analysis
105    Subjective,
106    /// General conversational query
107    Conversational,
108}
109
110/// Intent classification for the query
111#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
112pub enum QueryIntent {
113    /// User wants to find specific information
114    Search,
115    /// User wants an explanation or understanding
116    Explain,
117    /// User wants help with a task
118    Assist,
119    /// User wants to compare options
120    Compare,
121    /// User wants step-by-step instructions
122    Guide,
123    /// User wants code or technical solution
124    Code,
125    /// User wants to troubleshoot an issue
126    Debug,
127    /// User is having a conversation
128    Chat,
129}
130
131/// Complexity level of the query
132#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
133pub enum QueryComplexity {
134    Simple,
135    Medium,
136    Complex,
137    VeryComplex,
138}
139
140/// Domain classification for the query
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct QueryDomain {
143    pub primary_domain: String,
144    pub secondary_domains: Vec<String>,
145    pub confidence: f32,
146}
147
148/// Extracted entities from the query
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct QueryEntity {
151    pub text: String,
152    pub entity_type: String,
153    pub start_pos: usize,
154    pub end_pos: usize,
155    pub confidence: f32,
156}
157
158/// Features extracted from the query
159#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct QueryFeatures {
161    pub word_count: usize,
162    pub sentence_count: usize,
163    pub question_words: Vec<String>,
164    pub technical_terms: Vec<String>,
165    pub has_code: bool,
166    pub has_numbers: bool,
167    pub has_dates: bool,
168    pub language: String,
169}
170
171/// Comprehensive query understanding result
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct QueryUnderstanding {
174    pub original_query: String,
175    pub query_type: QueryType,
176    pub intent: QueryIntent,
177    pub complexity: QueryComplexity,
178    pub domain: QueryDomain,
179    pub entities: Vec<QueryEntity>,
180    pub features: QueryFeatures,
181    pub keywords: Vec<String>,
182    pub confidence: f32,
183}
184
185/// Helper struct for analyzing query complexity metrics
186#[derive(Debug)]
187struct QueryComplexityMetrics {
188    word_count: usize,
189    sentence_count: usize,
190    has_technical_terms: bool,
191    has_multiple_questions: bool,
192}
193
194impl QueryComplexityMetrics {
195    fn analyze(query: &str) -> Self {
196        let word_count = query.split_whitespace().count();
197        let sentence_count = query.split('.').count();
198        let has_technical_terms = QueryUnderstandingService::has_technical_terms(query);
199        let has_multiple_questions = query.matches('?').count() > 1;
200        
201        Self {
202            word_count,
203            sentence_count,
204            has_technical_terms,
205            has_multiple_questions,
206        }
207    }
208}
209
/// Query understanding service with optimized pattern matching.
///
/// The service holds no per-instance state, so it is cheap to construct and
/// to clone.
#[derive(Debug, Clone)]
pub struct QueryUnderstandingService {
    // Using static data instead of instance data for better performance
}
214
215impl QueryUnderstandingService {
216    pub fn new() -> Self {
217        Self {}
218    }
219
220    /// Analyze a query and return comprehensive understanding
221    pub fn understand_query(&self, query: &str) -> Result<QueryUnderstanding> {
222        let normalized_query = query.to_lowercase().trim().to_string();
223        
224        let query_type = self.classify_query_type(&normalized_query);
225        let intent = self.classify_intent(&normalized_query);
226        let complexity = self.classify_complexity(&normalized_query);
227        let domain = self.classify_domain(&normalized_query);
228        let entities = self.extract_entities(&normalized_query);
229        let features = self.extract_features(&normalized_query);
230        let keywords = self.extract_keywords(&normalized_query);
231        let confidence = self.calculate_confidence(&normalized_query, &query_type, &intent);
232
233        Ok(QueryUnderstanding {
234            original_query: query.to_string(),
235            query_type,
236            intent,
237            complexity,
238            domain,
239            entities,
240            features,
241            keywords,
242            confidence,
243        })
244    }
245
246    /// Classify the type of query
247    fn classify_query_type(&self, query: &str) -> QueryType {
248        // Check for definitional queries
249        if query.contains("what is") || query.contains("define") || query.contains("definition") {
250            return QueryType::Definitional;
251        }
252
253        // Check for procedural queries
254        if query.contains("how to") || query.contains("steps") || query.contains("process") {
255            return QueryType::Procedural;
256        }
257
258        // Check for comparative queries
259        if query.contains("compare") || query.contains("difference") || query.contains("vs") || 
260           query.contains("versus") || query.contains("better") {
261            return QueryType::Comparative;
262        }
263
264        // Check for enumerative queries
265        if query.contains("list") || query.contains("examples") || query.contains("types of") {
266            return QueryType::Enumerative;
267        }
268
269        // Check for technical queries
270        if self.has_code_patterns(query) || Self::has_technical_terms(query) {
271            return QueryType::Technical;
272        }
273
274        // Check for analytical queries
275        if query.contains("why") || query.contains("analyze") || query.contains("explain") {
276            return QueryType::Analytical;
277        }
278
279        // Check for subjective queries
280        if query.contains("opinion") || query.contains("think") || query.contains("feel") ||
281           query.contains("recommend") {
282            return QueryType::Subjective;
283        }
284
285        // Default to factual for simple questions
286        QueryType::Factual
287    }
288
289    /// Classify the intent of the query
290    fn classify_intent(&self, query: &str) -> QueryIntent {
291        // Check more specific intents first before general ones
292        if query.contains("error") || query.contains("debug") || query.contains("fix") ||
293           query.contains("problem") {
294            return QueryIntent::Debug;
295        }
296
297        if self.has_code_patterns(query) || query.contains("code") || query.contains("implement") {
298            return QueryIntent::Code;
299        }
300
301        if query.contains("compare") || query.contains("difference") || query.contains("vs") {
302            return QueryIntent::Compare;
303        }
304
305        if query.contains("steps") || query.contains("guide") || query.contains("tutorial") {
306            return QueryIntent::Guide;
307        }
308
309        if query.contains("explain") || query.contains("understand") || query.contains("what") {
310            return QueryIntent::Explain;
311        }
312
313        if query.contains("help") || query.contains("assist") || query.contains("how to") {
314            return QueryIntent::Assist;
315        }
316
317        if query.contains("hello") || query.contains("thanks") || query.len() < 20 {
318            return QueryIntent::Chat;
319        }
320
321        QueryIntent::Search
322    }
323
324    /// Classify the complexity of the query
325    fn classify_complexity(&self, query: &str) -> QueryComplexity {
326        let regexes = get_query_regexes();
327        
328        // Check against predefined complexity patterns
329        if regexes.complexity_complex.is_match(query) {
330            return QueryComplexity::Complex;
331        }
332        if regexes.complexity_simple.is_match(query) {
333            return QueryComplexity::Simple;
334        }
335
336        let word_count = query.split_whitespace().count();
337        let sentence_count = query.split('.').count();
338        let has_technical = Self::has_technical_terms(query);
339        let has_multiple_questions = query.matches('?').count() > 1;
340
341        match (word_count, sentence_count, has_technical, has_multiple_questions) {
342            (w, s, true, true) if w > 30 && s > 3 => QueryComplexity::VeryComplex,
343            (w, s, _, true) if w > 20 && s > 2 => QueryComplexity::Complex,
344            (w, _, true, _) if w > 15 => QueryComplexity::Complex,
345            (w, _, _, _) if w > 10 => QueryComplexity::Medium,
346            _ => QueryComplexity::Simple,
347        }
348    }
349
350    /// Classify the domain of the query
351    fn classify_domain(&self, query: &str) -> QueryDomain {
352        let mut domain_scores: HashMap<String, f32> = HashMap::new();
353
354        // Check each technical domain
355        for (domain, keywords) in TECHNICAL_DOMAINS {
356            let mut score = 0.0;
357            for keyword in *keywords {
358                if query.contains(keyword) {
359                    score += 1.0;
360                }
361            }
362            if score > 0.0 {
363                domain_scores.insert(domain.to_string(), score / keywords.len() as f32);
364            }
365        }
366
367        // Find the best matching domain
368        if let Some((primary_domain, confidence)) = domain_scores.iter()
369            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) {
370            
371            let mut secondary_domains: Vec<String> = domain_scores
372                .iter()
373                .filter(|(domain, score)| *domain != primary_domain && **score > 0.3)
374                .map(|(domain, _)| domain.clone())
375                .collect();
376            secondary_domains.sort();
377
378            QueryDomain {
379                primary_domain: primary_domain.clone(),
380                secondary_domains,
381                confidence: *confidence,
382            }
383        } else {
384            QueryDomain {
385                primary_domain: "general".to_string(),
386                secondary_domains: Vec::new(),
387                confidence: 0.5,
388            }
389        }
390    }
391
392    /// Extract named entities from the query
393    fn extract_entities(&self, query: &str) -> Vec<QueryEntity> {
394        let mut entities = Vec::new();
395
396        // Simple entity extraction patterns
397        let patterns = vec![
398            (r"\b\d{4}\b", "year"),
399            (r"\b\d+\.\d+\.\d+\b", "version"),
400            (r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b", "proper_noun"),
401            (r"\b\w+\(\)", "function"),
402            (r"\b\w+\.\w+\b", "method_or_attribute"),
403        ];
404
405        for (pattern, entity_type) in patterns {
406            if let Ok(regex) = Regex::new(pattern) {
407                for mat in regex.find_iter(query) {
408                    entities.push(QueryEntity {
409                        text: mat.as_str().to_string(),
410                        entity_type: entity_type.to_string(),
411                        start_pos: mat.start(),
412                        end_pos: mat.end(),
413                        confidence: 0.8,
414                    });
415                }
416            }
417        }
418
419        entities
420    }
421
422    /// Extract features from the query
423    fn extract_features(&self, query: &str) -> QueryFeatures {
424        let words: Vec<&str> = query.split_whitespace().collect();
425        let sentences: Vec<&str> = query.split('.').collect();
426
427        let question_words = words
428            .iter()
429            .filter(|word| QUESTION_WORDS.contains(&word.to_lowercase().as_str()))
430            .map(|word| word.to_string())
431            .collect();
432
433        let technical_terms = self.extract_technical_terms(query);
434
435        QueryFeatures {
436            word_count: words.len(),
437            sentence_count: sentences.len(),
438            question_words,
439            technical_terms,
440            has_code: self.has_code_patterns(query),
441            has_numbers: query.chars().any(|c| c.is_ascii_digit()),
442            has_dates: self.has_date_patterns(query),
443            language: "en".to_string(), // Simple language detection
444        }
445    }
446
447    /// Extract keywords from the query
448    fn extract_keywords(&self, query: &str) -> Vec<String> {
449        let stop_words = vec![
450            "a", "an", "and", "are", "as", "at", "be", "by", "for", "from",
451            "has", "he", "in", "is", "it", "its", "of", "on", "that", "the",
452            "to", "was", "were", "will", "with", "the", "this", "but", "they",
453            "have", "had", "what", "said", "each", "which", "she", "do", "how",
454        ];
455
456        query
457            .split_whitespace()
458            .filter(|word| {
459                let word = word.to_lowercase();
460                word.len() > 2 && !stop_words.contains(&word.as_str())
461            })
462            .map(|word| word.to_lowercase())
463            .collect()
464    }
465
466    /// Calculate confidence in the query understanding
467    fn calculate_confidence(&self, query: &str, query_type: &QueryType, _intent: &QueryIntent) -> f32 {
468        let mut confidence: f32 = 0.5; // Base confidence
469
470        // Boost confidence for clear patterns
471        if self.has_clear_question_words(query) {
472            confidence += 0.2;
473        }
474
475        if Self::has_technical_terms(query) && matches!(query_type, QueryType::Technical) {
476            confidence += 0.2;
477        }
478
479        if query.ends_with('?') {
480            confidence += 0.1;
481        }
482
483        // Reduce confidence for very short or very long queries
484        let word_count = query.split_whitespace().count();
485        if word_count < 3 || word_count > 50 {
486            confidence -= 0.1;
487        }
488
489        confidence.min(1.0_f32).max(0.0_f32)
490    }
491
492
493
494    /// Check if query has code patterns
495    fn has_code_patterns(&self, query: &str) -> bool {
496        let regexes = get_query_regexes();
497        regexes.code_function_call.is_match(query) ||
498        regexes.code_method_access.is_match(query) ||
499        regexes.code_punctuation.is_match(query) ||
500        regexes.code_keywords.is_match(query)
501    }
502
503    /// Check if query has technical terms
504    fn has_technical_terms(query: &str) -> bool {
505        TECHNICAL_DOMAINS.iter().any(|(_, terms)| {
506            terms.iter().any(|term| query.contains(term))
507        })
508    }
509
510    /// Check if query has clear question words
511    fn has_clear_question_words(&self, query: &str) -> bool {
512        QUESTION_WORDS.iter().any(|word| query.contains(word))
513    }
514
515    /// Check if query has date patterns
516    fn has_date_patterns(&self, query: &str) -> bool {
517        let regexes = get_query_regexes();
518        regexes.year_pattern.is_match(query) ||
519        regexes.date_pattern.is_match(query) ||
520        regexes.month_pattern.is_match(query)
521    }
522
523    /// Extract technical terms from query
524    fn extract_technical_terms(&self, query: &str) -> Vec<String> {
525        let mut terms = Vec::new();
526
527        for (_, domain_terms) in TECHNICAL_DOMAINS {
528            for term in *domain_terms {
529                if query.contains(term) {
530                    terms.push(term.to_string());
531                }
532            }
533        }
534
535        terms
536    }
537}
538
539impl Default for QueryUnderstandingService {
540    fn default() -> Self {
541        Self::new()
542    }
543}
544
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a fresh service on `query` and unwraps the result.
    fn understand(query: &str) -> QueryUnderstanding {
        QueryUnderstandingService::new()
            .understand_query(query)
            .unwrap()
    }

    #[test]
    fn test_query_type_classification() {
        let cases = [
            ("What is machine learning?", QueryType::Definitional),
            ("How to implement a neural network?", QueryType::Procedural),
            ("Compare React vs Vue", QueryType::Comparative),
        ];

        for (query, expected) in cases {
            assert_eq!(understand(query).query_type, expected);
        }
    }

    #[test]
    fn test_intent_classification() {
        let cases = [
            ("Explain how neural networks work", QueryIntent::Explain),
            ("Help me debug this code", QueryIntent::Debug),
            ("Show me the steps to install Python", QueryIntent::Guide),
        ];

        for (query, expected) in cases {
            assert_eq!(understand(query).intent, expected);
        }
    }

    #[test]
    fn test_complexity_classification() {
        let cases = [
            ("Hi", QueryComplexity::Simple),
            (
                "How do I implement a complex distributed system with microservices architecture?",
                QueryComplexity::Complex,
            ),
        ];

        for (query, expected) in cases {
            assert_eq!(understand(query).complexity, expected);
        }
    }

    #[test]
    fn test_domain_classification() {
        let cases = [
            ("How to train a machine learning model?", "machine_learning"),
            ("Write a JavaScript function", "programming"),
        ];

        for (query, expected) in cases {
            assert_eq!(understand(query).domain.primary_domain, expected);
        }
    }

    #[test]
    fn test_feature_extraction() {
        let features = understand("What is the function setTimeout() in JavaScript?").features;

        assert!(features.word_count > 0);
        assert!(features.has_code);
        assert!(!features.question_words.is_empty());
    }

    #[test]
    fn test_keyword_extraction() {
        let keywords = understand("How to implement machine learning algorithms").keywords;

        for expected in ["implement", "machine", "learning", "algorithms"] {
            assert!(keywords.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_confidence_calculation() {
        // A well-formed question scores above the 0.5 baseline.
        let clear_question = understand("What is machine learning?");
        assert!(clear_question.confidence > 0.5);

        // A single-letter query is penalised for being too short.
        let too_short = understand("a");
        assert!(too_short.confidence < 0.5);
    }
}
630}