1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use crate::error::Result;
4use regex::Regex;
5use std::sync::OnceLock;
6
7struct QueryRegexes {
9 code_function_call: Regex,
10 code_method_access: Regex,
11 code_punctuation: Regex,
12 code_keywords: Regex,
13 complexity_complex: Regex,
14 complexity_simple: Regex,
15 year_pattern: Regex,
16 date_pattern: Regex,
17 month_pattern: Regex,
18}
19
20impl QueryRegexes {
21 fn new() -> Self {
22 Self {
23 code_function_call: Regex::new(r"\w+\(\)").unwrap(),
24 code_method_access: Regex::new(r"\w+\.\w+").unwrap(),
25 code_punctuation: Regex::new(r"[{}:;\[\]]").unwrap(),
26 code_keywords: Regex::new(r"(?i)\b(def|class|import|function|const|let|var)\b").unwrap(),
27 complexity_complex: Regex::new(r"(?i)\b(complex|advanced|sophisticated|intricate)\b").unwrap(),
28 complexity_simple: Regex::new(r"(?i)\b(simple|basic|easy|straightforward)\b").unwrap(),
29 year_pattern: Regex::new(r"\b\d{4}\b").unwrap(),
30 date_pattern: Regex::new(r"\b\d{1,2}/\d{1,2}/\d{4}\b").unwrap(),
31 month_pattern: Regex::new(r"(?i)\b(january|february|march|april|may|june|july|august|september|october|november|december)\b").unwrap(),
32 }
33 }
34}
35
36static QUERY_REGEXES: OnceLock<QueryRegexes> = OnceLock::new();
37
38fn get_query_regexes() -> &'static QueryRegexes {
39 QUERY_REGEXES.get_or_init(QueryRegexes::new)
40}
41
42static QUERY_TYPE_PATTERNS: &[(QueryType, &[&str])] = &[
44 (QueryType::Definitional, &["what is", "define", "definition of", "meaning of"]),
45 (QueryType::Procedural, &["how to", "steps to", "process of", "method to"]),
46 (QueryType::Comparative, &["compare", "difference between", "vs", "versus", "better than"]),
47 (QueryType::Enumerative, &["list of", "examples of", "types of", "kinds of"]),
48 (QueryType::Analytical, &["why", "analyze", "explain", "reason"]),
49 (QueryType::Subjective, &["opinion", "think", "feel", "recommend", "suggest"]),
50];
51
52static QUERY_INTENT_PATTERNS: &[(QueryIntent, &[&str])] = &[
53 (QueryIntent::Debug, &["error", "debug", "fix", "problem", "issue", "bug"]),
54 (QueryIntent::Code, &["code", "implement", "function", "class", "method"]),
55 (QueryIntent::Compare, &["compare", "difference", "vs", "versus"]),
56 (QueryIntent::Guide, &["steps", "guide", "tutorial", "instructions"]),
57 (QueryIntent::Explain, &["explain", "understand", "what", "clarify"]),
58 (QueryIntent::Assist, &["help", "assist", "how to", "need"]),
59 (QueryIntent::Chat, &["hello", "hi", "thanks", "thank you"]),
60];
61
/// Vocabulary signalling a general programming question.
const PROGRAMMING_TERMS: &[&str] = &[
    "code", "function", "variable", "algorithm", "programming", "software",
    "debug", "api", "library", "javascript", "python", "java", "rust", "typescript",
];

/// Vocabulary signalling a machine-learning question.
const MACHINE_LEARNING_TERMS: &[&str] = &[
    "machine learning", "neural network", "model", "training", "dataset",
    "prediction", "classification", "ai", "artificial intelligence",
];

/// Vocabulary signalling a web-development question.
const WEB_DEVELOPMENT_TERMS: &[&str] = &[
    "html", "css", "javascript", "react", "vue", "angular",
    "frontend", "backend", "web", "http", "api", "rest",
];

/// Vocabulary signalling a database question.
const DATABASE_TERMS: &[&str] = &[
    "database", "sql", "query", "table", "index", "schema",
    "postgres", "mysql", "mongodb", "nosql",
];

/// Technical vocabulary grouped by domain, as `(domain_name, terms)` pairs.
///
/// All terms are lower-case, matching the lower-cased query text they are
/// compared against. A few terms (`"javascript"`, `"api"`) are listed under
/// more than one domain.
static TECHNICAL_DOMAINS: &[(&str, &[&str])] = &[
    ("programming", PROGRAMMING_TERMS),
    ("machine_learning", MACHINE_LEARNING_TERMS),
    ("web_development", WEB_DEVELOPMENT_TERMS),
    ("database", DATABASE_TERMS),
];
80
/// Words that typically open or signal a question, compared against
/// lower-cased query text.
static QUESTION_WORDS: &[&str] = &[
    // Interrogatives.
    "what", "how", "why", "when", "where", "who", "which", "whose",
    // Modals that open yes/no questions.
    "can", "could", "should", "would", "will",
    // Auxiliary "do".
    "do", "does", "did",
    // Auxiliary "be".
    "is", "are", "was", "were",
    // Auxiliary "have".
    "have", "has", "had",
];
86
87#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
89pub enum QueryType {
90 Factual,
92 Analytical,
94 Comparative,
96 Enumerative,
98 Definitional,
100 Procedural,
102 Technical,
104 Subjective,
106 Conversational,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
112pub enum QueryIntent {
113 Search,
115 Explain,
117 Assist,
119 Compare,
121 Guide,
123 Code,
125 Debug,
127 Chat,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
133pub enum QueryComplexity {
134 Simple,
135 Medium,
136 Complex,
137 VeryComplex,
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct QueryDomain {
143 pub primary_domain: String,
144 pub secondary_domains: Vec<String>,
145 pub confidence: f32,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct QueryEntity {
151 pub text: String,
152 pub entity_type: String,
153 pub start_pos: usize,
154 pub end_pos: usize,
155 pub confidence: f32,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct QueryFeatures {
161 pub word_count: usize,
162 pub sentence_count: usize,
163 pub question_words: Vec<String>,
164 pub technical_terms: Vec<String>,
165 pub has_code: bool,
166 pub has_numbers: bool,
167 pub has_dates: bool,
168 pub language: String,
169}
170
/// Complete analysis result produced by
/// [`QueryUnderstandingService::understand_query`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryUnderstanding {
    /// The query exactly as passed in (not trimmed or lower-cased).
    pub original_query: String,
    /// Coarse category of the question.
    pub query_type: QueryType,
    /// Most likely user intent.
    pub intent: QueryIntent,
    /// Estimated complexity.
    pub complexity: QueryComplexity,
    /// Subject-area classification.
    pub domain: QueryDomain,
    /// Pattern-based entities found in the query.
    pub entities: Vec<QueryEntity>,
    /// Surface-level features (counts, flags, detected terms).
    pub features: QueryFeatures,
    /// Lower-cased words left after dropping stop words and very short words.
    pub keywords: Vec<String>,
    /// Overall heuristic confidence, clamped to `0.0..=1.0`.
    pub confidence: f32,
}
184
185#[derive(Debug)]
187struct QueryComplexityMetrics {
188 word_count: usize,
189 sentence_count: usize,
190 has_technical_terms: bool,
191 has_multiple_questions: bool,
192}
193
194impl QueryComplexityMetrics {
195 fn analyze(query: &str) -> Self {
196 let word_count = query.split_whitespace().count();
197 let sentence_count = query.split('.').count();
198 let has_technical_terms = QueryUnderstandingService::has_technical_terms(query);
199 let has_multiple_questions = query.matches('?').count() > 1;
200
201 Self {
202 word_count,
203 sentence_count,
204 has_technical_terms,
205 has_multiple_questions,
206 }
207 }
208}
209
/// Rule-based analyzer that classifies free-text queries.
///
/// The service holds no state. All keyword tables and compiled regexes live
/// in module-level statics.
#[derive(Debug, Clone)]
pub struct QueryUnderstandingService {}
214
215impl QueryUnderstandingService {
216 pub fn new() -> Self {
217 Self {}
218 }
219
220 pub fn understand_query(&self, query: &str) -> Result<QueryUnderstanding> {
222 let normalized_query = query.to_lowercase().trim().to_string();
223
224 let query_type = self.classify_query_type(&normalized_query);
225 let intent = self.classify_intent(&normalized_query);
226 let complexity = self.classify_complexity(&normalized_query);
227 let domain = self.classify_domain(&normalized_query);
228 let entities = self.extract_entities(&normalized_query);
229 let features = self.extract_features(&normalized_query);
230 let keywords = self.extract_keywords(&normalized_query);
231 let confidence = self.calculate_confidence(&normalized_query, &query_type, &intent);
232
233 Ok(QueryUnderstanding {
234 original_query: query.to_string(),
235 query_type,
236 intent,
237 complexity,
238 domain,
239 entities,
240 features,
241 keywords,
242 confidence,
243 })
244 }
245
246 fn classify_query_type(&self, query: &str) -> QueryType {
248 if query.contains("what is") || query.contains("define") || query.contains("definition") {
250 return QueryType::Definitional;
251 }
252
253 if query.contains("how to") || query.contains("steps") || query.contains("process") {
255 return QueryType::Procedural;
256 }
257
258 if query.contains("compare") || query.contains("difference") || query.contains("vs") ||
260 query.contains("versus") || query.contains("better") {
261 return QueryType::Comparative;
262 }
263
264 if query.contains("list") || query.contains("examples") || query.contains("types of") {
266 return QueryType::Enumerative;
267 }
268
269 if self.has_code_patterns(query) || Self::has_technical_terms(query) {
271 return QueryType::Technical;
272 }
273
274 if query.contains("why") || query.contains("analyze") || query.contains("explain") {
276 return QueryType::Analytical;
277 }
278
279 if query.contains("opinion") || query.contains("think") || query.contains("feel") ||
281 query.contains("recommend") {
282 return QueryType::Subjective;
283 }
284
285 QueryType::Factual
287 }
288
289 fn classify_intent(&self, query: &str) -> QueryIntent {
291 if query.contains("error") || query.contains("debug") || query.contains("fix") ||
293 query.contains("problem") {
294 return QueryIntent::Debug;
295 }
296
297 if self.has_code_patterns(query) || query.contains("code") || query.contains("implement") {
298 return QueryIntent::Code;
299 }
300
301 if query.contains("compare") || query.contains("difference") || query.contains("vs") {
302 return QueryIntent::Compare;
303 }
304
305 if query.contains("steps") || query.contains("guide") || query.contains("tutorial") {
306 return QueryIntent::Guide;
307 }
308
309 if query.contains("explain") || query.contains("understand") || query.contains("what") {
310 return QueryIntent::Explain;
311 }
312
313 if query.contains("help") || query.contains("assist") || query.contains("how to") {
314 return QueryIntent::Assist;
315 }
316
317 if query.contains("hello") || query.contains("thanks") || query.len() < 20 {
318 return QueryIntent::Chat;
319 }
320
321 QueryIntent::Search
322 }
323
324 fn classify_complexity(&self, query: &str) -> QueryComplexity {
326 let regexes = get_query_regexes();
327
328 if regexes.complexity_complex.is_match(query) {
330 return QueryComplexity::Complex;
331 }
332 if regexes.complexity_simple.is_match(query) {
333 return QueryComplexity::Simple;
334 }
335
336 let word_count = query.split_whitespace().count();
337 let sentence_count = query.split('.').count();
338 let has_technical = Self::has_technical_terms(query);
339 let has_multiple_questions = query.matches('?').count() > 1;
340
341 match (word_count, sentence_count, has_technical, has_multiple_questions) {
342 (w, s, true, true) if w > 30 && s > 3 => QueryComplexity::VeryComplex,
343 (w, s, _, true) if w > 20 && s > 2 => QueryComplexity::Complex,
344 (w, _, true, _) if w > 15 => QueryComplexity::Complex,
345 (w, _, _, _) if w > 10 => QueryComplexity::Medium,
346 _ => QueryComplexity::Simple,
347 }
348 }
349
350 fn classify_domain(&self, query: &str) -> QueryDomain {
352 let mut domain_scores: HashMap<String, f32> = HashMap::new();
353
354 for (domain, keywords) in TECHNICAL_DOMAINS {
356 let mut score = 0.0;
357 for keyword in *keywords {
358 if query.contains(keyword) {
359 score += 1.0;
360 }
361 }
362 if score > 0.0 {
363 domain_scores.insert(domain.to_string(), score / keywords.len() as f32);
364 }
365 }
366
367 if let Some((primary_domain, confidence)) = domain_scores.iter()
369 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) {
370
371 let mut secondary_domains: Vec<String> = domain_scores
372 .iter()
373 .filter(|(domain, score)| *domain != primary_domain && **score > 0.3)
374 .map(|(domain, _)| domain.clone())
375 .collect();
376 secondary_domains.sort();
377
378 QueryDomain {
379 primary_domain: primary_domain.clone(),
380 secondary_domains,
381 confidence: *confidence,
382 }
383 } else {
384 QueryDomain {
385 primary_domain: "general".to_string(),
386 secondary_domains: Vec::new(),
387 confidence: 0.5,
388 }
389 }
390 }
391
392 fn extract_entities(&self, query: &str) -> Vec<QueryEntity> {
394 let mut entities = Vec::new();
395
396 let patterns = vec![
398 (r"\b\d{4}\b", "year"),
399 (r"\b\d+\.\d+\.\d+\b", "version"),
400 (r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b", "proper_noun"),
401 (r"\b\w+\(\)", "function"),
402 (r"\b\w+\.\w+\b", "method_or_attribute"),
403 ];
404
405 for (pattern, entity_type) in patterns {
406 if let Ok(regex) = Regex::new(pattern) {
407 for mat in regex.find_iter(query) {
408 entities.push(QueryEntity {
409 text: mat.as_str().to_string(),
410 entity_type: entity_type.to_string(),
411 start_pos: mat.start(),
412 end_pos: mat.end(),
413 confidence: 0.8,
414 });
415 }
416 }
417 }
418
419 entities
420 }
421
422 fn extract_features(&self, query: &str) -> QueryFeatures {
424 let words: Vec<&str> = query.split_whitespace().collect();
425 let sentences: Vec<&str> = query.split('.').collect();
426
427 let question_words = words
428 .iter()
429 .filter(|word| QUESTION_WORDS.contains(&word.to_lowercase().as_str()))
430 .map(|word| word.to_string())
431 .collect();
432
433 let technical_terms = self.extract_technical_terms(query);
434
435 QueryFeatures {
436 word_count: words.len(),
437 sentence_count: sentences.len(),
438 question_words,
439 technical_terms,
440 has_code: self.has_code_patterns(query),
441 has_numbers: query.chars().any(|c| c.is_ascii_digit()),
442 has_dates: self.has_date_patterns(query),
443 language: "en".to_string(), }
445 }
446
447 fn extract_keywords(&self, query: &str) -> Vec<String> {
449 let stop_words = vec![
450 "a", "an", "and", "are", "as", "at", "be", "by", "for", "from",
451 "has", "he", "in", "is", "it", "its", "of", "on", "that", "the",
452 "to", "was", "were", "will", "with", "the", "this", "but", "they",
453 "have", "had", "what", "said", "each", "which", "she", "do", "how",
454 ];
455
456 query
457 .split_whitespace()
458 .filter(|word| {
459 let word = word.to_lowercase();
460 word.len() > 2 && !stop_words.contains(&word.as_str())
461 })
462 .map(|word| word.to_lowercase())
463 .collect()
464 }
465
466 fn calculate_confidence(&self, query: &str, query_type: &QueryType, _intent: &QueryIntent) -> f32 {
468 let mut confidence: f32 = 0.5; if self.has_clear_question_words(query) {
472 confidence += 0.2;
473 }
474
475 if Self::has_technical_terms(query) && matches!(query_type, QueryType::Technical) {
476 confidence += 0.2;
477 }
478
479 if query.ends_with('?') {
480 confidence += 0.1;
481 }
482
483 let word_count = query.split_whitespace().count();
485 if word_count < 3 || word_count > 50 {
486 confidence -= 0.1;
487 }
488
489 confidence.min(1.0_f32).max(0.0_f32)
490 }
491
492
493
494 fn has_code_patterns(&self, query: &str) -> bool {
496 let regexes = get_query_regexes();
497 regexes.code_function_call.is_match(query) ||
498 regexes.code_method_access.is_match(query) ||
499 regexes.code_punctuation.is_match(query) ||
500 regexes.code_keywords.is_match(query)
501 }
502
503 fn has_technical_terms(query: &str) -> bool {
505 TECHNICAL_DOMAINS.iter().any(|(_, terms)| {
506 terms.iter().any(|term| query.contains(term))
507 })
508 }
509
510 fn has_clear_question_words(&self, query: &str) -> bool {
512 QUESTION_WORDS.iter().any(|word| query.contains(word))
513 }
514
515 fn has_date_patterns(&self, query: &str) -> bool {
517 let regexes = get_query_regexes();
518 regexes.year_pattern.is_match(query) ||
519 regexes.date_pattern.is_match(query) ||
520 regexes.month_pattern.is_match(query)
521 }
522
523 fn extract_technical_terms(&self, query: &str) -> Vec<String> {
525 let mut terms = Vec::new();
526
527 for (_, domain_terms) in TECHNICAL_DOMAINS {
528 for term in *domain_terms {
529 if query.contains(term) {
530 terms.push(term.to_string());
531 }
532 }
533 }
534
535 terms
536 }
537}
538
539impl Default for QueryUnderstandingService {
540 fn default() -> Self {
541 Self::new()
542 }
543}
544
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the full analysis on `query`, panicking if it fails.
    fn analyze(query: &str) -> QueryUnderstanding {
        QueryUnderstandingService::new()
            .understand_query(query)
            .unwrap()
    }

    #[test]
    fn test_query_type_classification() {
        let cases = [
            ("What is machine learning?", QueryType::Definitional),
            ("How to implement a neural network?", QueryType::Procedural),
            ("Compare React vs Vue", QueryType::Comparative),
        ];
        for (query, expected) in cases {
            assert_eq!(analyze(query).query_type, expected, "query: {query}");
        }
    }

    #[test]
    fn test_intent_classification() {
        let cases = [
            ("Explain how neural networks work", QueryIntent::Explain),
            ("Help me debug this code", QueryIntent::Debug),
            ("Show me the steps to install Python", QueryIntent::Guide),
        ];
        for (query, expected) in cases {
            assert_eq!(analyze(query).intent, expected, "query: {query}");
        }
    }

    #[test]
    fn test_complexity_classification() {
        let cases = [
            ("Hi", QueryComplexity::Simple),
            (
                "How do I implement a complex distributed system with microservices architecture?",
                QueryComplexity::Complex,
            ),
        ];
        for (query, expected) in cases {
            assert_eq!(analyze(query).complexity, expected, "query: {query}");
        }
    }

    #[test]
    fn test_domain_classification() {
        let cases = [
            ("How to train a machine learning model?", "machine_learning"),
            ("Write a JavaScript function", "programming"),
        ];
        for (query, expected) in cases {
            assert_eq!(analyze(query).domain.primary_domain, expected, "query: {query}");
        }
    }

    #[test]
    fn test_feature_extraction() {
        let features = analyze("What is the function setTimeout() in JavaScript?").features;
        assert!(features.word_count > 0);
        assert!(features.has_code);
        assert!(!features.question_words.is_empty());
    }

    #[test]
    fn test_keyword_extraction() {
        let keywords = analyze("How to implement machine learning algorithms").keywords;
        for expected in ["implement", "machine", "learning", "algorithms"] {
            assert!(
                keywords.iter().any(|keyword| keyword == expected),
                "missing keyword {expected:?} in {keywords:?}"
            );
        }
    }

    #[test]
    fn test_confidence_calculation() {
        assert!(analyze("What is machine learning?").confidence > 0.5);
        assert!(analyze("a").confidence < 0.5);
    }
}