// rexis_rag/query/classifier.rs

//! # Query Classifier
//!
//! Intelligent classification of user queries to determine intent and appropriate search strategies.
//! Helps optimize retrieval by understanding what the user is looking for.

use crate::RragResult;
use serde::{Deserialize, Serialize};

9/// Query classifier for intent detection
10pub struct QueryClassifier {
11    /// Intent patterns for classification
12    patterns: Vec<IntentPattern>,
13
14    /// Type patterns for classification
15    type_patterns: Vec<TypePattern>,
16}
17
18/// Query intent categories
19#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
20pub enum QueryIntent {
21    /// Factual information seeking
22    Factual,
23    /// Conceptual understanding
24    Conceptual,
25    /// Procedural how-to questions
26    Procedural,
27    /// Comparative analysis
28    Comparative,
29    /// Troubleshooting and problem-solving
30    Troubleshooting,
31    /// Exploratory research
32    Exploratory,
33    /// Definitional queries
34    Definitional,
35    /// Opinion or recommendation seeking
36    OpinionSeeking,
37}
38
39/// Query type categories
40#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
41pub enum QueryType {
42    /// Direct question
43    Question,
44    /// Command or request
45    Command,
46    /// Keyword search
47    Keywords,
48    /// Natural language statement
49    Statement,
50    /// Complex multi-part query
51    Complex,
52}
53
54/// Result of query classification
55#[derive(Debug, Clone)]
56pub struct ClassificationResult {
57    /// Original query
58    pub query: String,
59
60    /// Detected intent
61    pub intent: QueryIntent,
62
63    /// Detected query type
64    pub query_type: QueryType,
65
66    /// Confidence score (0.0 to 1.0)
67    pub confidence: f32,
68
69    /// Additional metadata about the query
70    pub metadata: ClassificationMetadata,
71}
72
/// Additional metadata from classification.
#[derive(Debug, Clone)]
pub struct ClassificationMetadata {
    /// Key entities detected in the query.
    pub entities: Vec<String>,

    /// Domain/topic detected, or `None` when no domain matched.
    pub domain: Option<String>,

    /// Complexity score; the classifier caps it at 1.0.
    pub complexity: f32,

    /// Whether the query requires conversational context to be understood.
    pub needs_context: bool,

    /// Suggested search strategies, as strategy identifiers.
    pub suggested_strategies: Vec<String>,
}

92/// Pattern for intent detection
93struct IntentPattern {
94    /// Intent this pattern detects
95    intent: QueryIntent,
96    /// Keywords that indicate this intent
97    keywords: Vec<String>,
98    /// Phrase patterns
99    patterns: Vec<String>,
100    /// Confidence score
101    confidence: f32,
102}
103
104/// Pattern for type detection
105struct TypePattern {
106    /// Query type this pattern detects
107    query_type: QueryType,
108    /// Indicators for this type
109    indicators: Vec<String>,
110    /// Confidence score
111    confidence: f32,
112}
113
114impl QueryClassifier {
115    /// Create a new query classifier
116    pub fn new() -> Self {
117        let patterns = Self::init_intent_patterns();
118        let type_patterns = Self::init_type_patterns();
119
120        Self {
121            patterns,
122            type_patterns,
123        }
124    }
125
126    /// Classify a query to determine intent and type
127    pub async fn classify(&self, query: &str) -> RragResult<ClassificationResult> {
128        let query_lower = query.to_lowercase();
129        let tokens = self.tokenize(&query_lower);
130
131        // Detect intent
132        let (intent, intent_confidence) = self.detect_intent(&query_lower, &tokens);
133
134        // Detect query type
135        let (query_type, type_confidence) = self.detect_query_type(&query_lower, &tokens);
136
137        // Extract entities
138        let entities = self.extract_entities(&tokens);
139
140        // Detect domain
141        let domain = self.detect_domain(&tokens);
142
143        // Calculate complexity
144        let complexity = self.calculate_complexity(query, &tokens);
145
146        // Determine if context is needed
147        let needs_context = self.needs_context(query, &tokens);
148
149        // Suggest search strategies
150        let suggested_strategies = self.suggest_strategies(&intent, &query_type, complexity);
151
152        // Overall confidence is the minimum of intent and type confidence
153        let confidence = intent_confidence.min(type_confidence);
154
155        Ok(ClassificationResult {
156            query: query.to_string(),
157            intent,
158            query_type,
159            confidence,
160            metadata: ClassificationMetadata {
161                entities,
162                domain,
163                complexity,
164                needs_context,
165                suggested_strategies,
166            },
167        })
168    }
169
170    /// Detect query intent
171    fn detect_intent(&self, query: &str, tokens: &[String]) -> (QueryIntent, f32) {
172        let mut best_intent = QueryIntent::Factual;
173        let mut best_confidence = 0.0;
174
175        for pattern in &self.patterns {
176            let mut score = 0.0;
177            let mut matches = 0;
178
179            // Check keyword matches
180            for keyword in &pattern.keywords {
181                if tokens.iter().any(|t| t.contains(keyword)) {
182                    score += 1.0;
183                    matches += 1;
184                }
185            }
186
187            // Check phrase patterns
188            for phrase in &pattern.patterns {
189                if query.contains(phrase) {
190                    score += 2.0; // Phrase matches are stronger
191                    matches += 1;
192                }
193            }
194
195            if matches > 0 {
196                // Normalize score
197                let normalized_score = (score
198                    / (pattern.keywords.len() + pattern.patterns.len()) as f32)
199                    * pattern.confidence;
200
201                if normalized_score > best_confidence {
202                    best_intent = pattern.intent.clone();
203                    best_confidence = normalized_score;
204                }
205            }
206        }
207
208        // Default fallback based on simple heuristics
209        if best_confidence < 0.3 {
210            if query.starts_with("what is") || query.starts_with("define") {
211                best_intent = QueryIntent::Definitional;
212                best_confidence = 0.6;
213            } else if query.starts_with("how to") || query.contains("step") {
214                best_intent = QueryIntent::Procedural;
215                best_confidence = 0.6;
216            } else if query.contains("compare")
217                || query.contains("vs")
218                || query.contains("difference")
219            {
220                best_intent = QueryIntent::Comparative;
221                best_confidence = 0.6;
222            }
223        }
224
225        (best_intent, best_confidence)
226    }
227
228    /// Detect query type
229    fn detect_query_type(&self, query: &str, tokens: &[String]) -> (QueryType, f32) {
230        let mut best_type = QueryType::Keywords;
231        let mut best_confidence = 0.0;
232
233        for pattern in &self.type_patterns {
234            let mut matches = 0;
235
236            for indicator in &pattern.indicators {
237                if query.contains(indicator) || tokens.iter().any(|t| t == indicator) {
238                    matches += 1;
239                }
240            }
241
242            if matches > 0 {
243                let confidence =
244                    (matches as f32 / pattern.indicators.len() as f32) * pattern.confidence;
245                if confidence > best_confidence {
246                    best_type = pattern.query_type.clone();
247                    best_confidence = confidence;
248                }
249            }
250        }
251
252        // Fallback heuristics
253        if best_confidence < 0.5 {
254            if query.ends_with('?') {
255                best_type = QueryType::Question;
256                best_confidence = 0.8;
257            } else if tokens.len() <= 3 {
258                best_type = QueryType::Keywords;
259                best_confidence = 0.7;
260            } else if tokens.len() > 10 {
261                best_type = QueryType::Complex;
262                best_confidence = 0.6;
263            } else {
264                best_type = QueryType::Statement;
265                best_confidence = 0.5;
266            }
267        }
268
269        (best_type, best_confidence)
270    }
271
272    /// Extract entities from query tokens
273    fn extract_entities(&self, tokens: &[String]) -> Vec<String> {
274        let mut entities = Vec::new();
275
276        // Simple entity extraction based on capitalization and known patterns
277        for token in tokens {
278            // Check if it looks like a proper noun (capitalized)
279            if token.chars().next().map_or(false, |c| c.is_uppercase()) {
280                entities.push(token.clone());
281            }
282
283            // Check for technical terms
284            let tech_terms = [
285                "api",
286                "sql",
287                "json",
288                "html",
289                "css",
290                "javascript",
291                "python",
292                "rust",
293                "docker",
294            ];
295            if tech_terms.contains(&token.to_lowercase().as_str()) {
296                entities.push(token.clone());
297            }
298        }
299
300        entities
301    }
302
303    /// Detect query domain
304    fn detect_domain(&self, tokens: &[String]) -> Option<String> {
305        let domains = [
306            (
307                "technology",
308                vec![
309                    "code",
310                    "programming",
311                    "software",
312                    "api",
313                    "database",
314                    "algorithm",
315                    "computer",
316                ],
317            ),
318            (
319                "science",
320                vec![
321                    "research",
322                    "study",
323                    "experiment",
324                    "theory",
325                    "analysis",
326                    "data",
327                    "scientific",
328                ],
329            ),
330            (
331                "business",
332                vec![
333                    "market",
334                    "sales",
335                    "revenue",
336                    "customer",
337                    "profit",
338                    "strategy",
339                    "management",
340                ],
341            ),
342            (
343                "health",
344                vec![
345                    "medical",
346                    "health",
347                    "disease",
348                    "treatment",
349                    "doctor",
350                    "medicine",
351                    "patient",
352                ],
353            ),
354            (
355                "education",
356                vec![
357                    "learn",
358                    "study",
359                    "school",
360                    "university",
361                    "course",
362                    "education",
363                    "teach",
364                ],
365            ),
366        ];
367
368        for (domain, keywords) in &domains {
369            let matches = keywords
370                .iter()
371                .filter(|&&keyword| tokens.iter().any(|t| t.contains(keyword)))
372                .count();
373
374            if matches >= 2 || (matches == 1 && tokens.len() <= 5) {
375                return Some(domain.to_string());
376            }
377        }
378
379        None
380    }
381
382    /// Calculate query complexity
383    fn calculate_complexity(&self, query: &str, tokens: &[String]) -> f32 {
384        let mut complexity = 0.0;
385
386        // Length factor
387        complexity += (tokens.len() as f32 / 20.0).min(1.0) * 0.3;
388
389        // Question words
390        let question_words = ["what", "how", "why", "when", "where", "which", "who"];
391        let question_count = question_words
392            .iter()
393            .filter(|&&word| tokens.iter().any(|t| t == word))
394            .count();
395        complexity += (question_count as f32 * 0.1).min(0.3);
396
397        // Conjunctions indicating complexity
398        let conjunctions = ["and", "or", "but", "however", "also", "additionally"];
399        let conjunction_count = conjunctions
400            .iter()
401            .filter(|&&word| tokens.iter().any(|t| t == word))
402            .count();
403        complexity += (conjunction_count as f32 * 0.15).min(0.2);
404
405        // Nested questions
406        if query.matches('?').count() > 1 {
407            complexity += 0.2;
408        }
409
410        complexity.min(1.0)
411    }
412
413    /// Determine if query needs conversational context
414    fn needs_context(&self, _query: &str, tokens: &[String]) -> bool {
415        let context_indicators = [
416            "this",
417            "that",
418            "it",
419            "they",
420            "them",
421            "previous",
422            "above",
423            "following",
424        ];
425        let pronouns = ["it", "this", "that", "these", "those"];
426
427        // Check for pronouns without clear antecedents
428        let has_pronouns = pronouns
429            .iter()
430            .any(|&pronoun| tokens.contains(&pronoun.to_string()));
431
432        // Check for context indicators
433        let has_context_indicators = context_indicators
434            .iter()
435            .any(|&indicator| tokens.contains(&indicator.to_string()));
436
437        // Very short queries often need context
438        let is_very_short = tokens.len() <= 2;
439
440        has_pronouns || has_context_indicators || is_very_short
441    }
442
443    /// Suggest search strategies based on classification
444    fn suggest_strategies(
445        &self,
446        intent: &QueryIntent,
447        query_type: &QueryType,
448        complexity: f32,
449    ) -> Vec<String> {
450        let mut strategies = Vec::new();
451
452        match intent {
453            QueryIntent::Factual => {
454                strategies.push("keyword_search".to_string());
455                strategies.push("exact_match".to_string());
456            }
457            QueryIntent::Conceptual => {
458                strategies.push("semantic_search".to_string());
459                strategies.push("related_documents".to_string());
460            }
461            QueryIntent::Procedural => {
462                strategies.push("step_by_step".to_string());
463                strategies.push("tutorial_search".to_string());
464            }
465            QueryIntent::Comparative => {
466                strategies.push("comparative_analysis".to_string());
467                strategies.push("side_by_side".to_string());
468            }
469            QueryIntent::Troubleshooting => {
470                strategies.push("problem_solution".to_string());
471                strategies.push("diagnostic".to_string());
472            }
473            QueryIntent::Exploratory => {
474                strategies.push("broad_search".to_string());
475                strategies.push("topic_exploration".to_string());
476            }
477            QueryIntent::Definitional => {
478                strategies.push("definition_search".to_string());
479                strategies.push("glossary_lookup".to_string());
480            }
481            QueryIntent::OpinionSeeking => {
482                strategies.push("review_search".to_string());
483                strategies.push("opinion_mining".to_string());
484            }
485        }
486
487        match query_type {
488            QueryType::Complex => {
489                strategies.push("query_decomposition".to_string());
490                strategies.push("multi_step_search".to_string());
491            }
492            QueryType::Keywords => {
493                strategies.push("keyword_expansion".to_string());
494                strategies.push("term_matching".to_string());
495            }
496            _ => {}
497        }
498
499        if complexity > 0.7 {
500            strategies.push("complex_reasoning".to_string());
501            strategies.push("multi_document_synthesis".to_string());
502        }
503
504        strategies
505    }
506
507    /// Tokenize query
508    fn tokenize(&self, query: &str) -> Vec<String> {
509        query
510            .split_whitespace()
511            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
512            .filter(|s| !s.is_empty())
513            .map(|s| s.to_lowercase())
514            .collect()
515    }
516
517    /// Initialize intent detection patterns
518    fn init_intent_patterns() -> Vec<IntentPattern> {
519        vec![
520            IntentPattern {
521                intent: QueryIntent::Definitional,
522                keywords: vec![
523                    "define".to_string(),
524                    "definition".to_string(),
525                    "meaning".to_string(),
526                ],
527                patterns: vec![
528                    "what is".to_string(),
529                    "what does".to_string(),
530                    "define".to_string(),
531                ],
532                confidence: 0.9,
533            },
534            IntentPattern {
535                intent: QueryIntent::Procedural,
536                keywords: vec![
537                    "how".to_string(),
538                    "step".to_string(),
539                    "tutorial".to_string(),
540                    "guide".to_string(),
541                ],
542                patterns: vec![
543                    "how to".to_string(),
544                    "step by step".to_string(),
545                    "how do i".to_string(),
546                ],
547                confidence: 0.9,
548            },
549            IntentPattern {
550                intent: QueryIntent::Comparative,
551                keywords: vec![
552                    "compare".to_string(),
553                    "difference".to_string(),
554                    "better".to_string(),
555                    "versus".to_string(),
556                ],
557                patterns: vec![
558                    "vs".to_string(),
559                    "compared to".to_string(),
560                    "difference between".to_string(),
561                ],
562                confidence: 0.8,
563            },
564            IntentPattern {
565                intent: QueryIntent::Troubleshooting,
566                keywords: vec![
567                    "problem".to_string(),
568                    "error".to_string(),
569                    "fix".to_string(),
570                    "issue".to_string(),
571                    "broken".to_string(),
572                ],
573                patterns: vec![
574                    "not working".to_string(),
575                    "how to fix".to_string(),
576                    "troubleshoot".to_string(),
577                ],
578                confidence: 0.8,
579            },
580            IntentPattern {
581                intent: QueryIntent::Factual,
582                keywords: vec![
583                    "when".to_string(),
584                    "where".to_string(),
585                    "who".to_string(),
586                    "which".to_string(),
587                ],
588                patterns: vec![
589                    "when did".to_string(),
590                    "where is".to_string(),
591                    "who created".to_string(),
592                ],
593                confidence: 0.7,
594            },
595        ]
596    }
597
598    /// Initialize query type patterns
599    fn init_type_patterns() -> Vec<TypePattern> {
600        vec![
601            TypePattern {
602                query_type: QueryType::Question,
603                indicators: vec![
604                    "?".to_string(),
605                    "what".to_string(),
606                    "how".to_string(),
607                    "why".to_string(),
608                    "when".to_string(),
609                    "where".to_string(),
610                ],
611                confidence: 0.9,
612            },
613            TypePattern {
614                query_type: QueryType::Command,
615                indicators: vec![
616                    "show".to_string(),
617                    "find".to_string(),
618                    "get".to_string(),
619                    "list".to_string(),
620                    "give".to_string(),
621                ],
622                confidence: 0.8,
623            },
624            TypePattern {
625                query_type: QueryType::Complex,
626                indicators: vec![
627                    "and".to_string(),
628                    "or".to_string(),
629                    "but".to_string(),
630                    "however".to_string(),
631                    "also".to_string(),
632                ],
633                confidence: 0.7,
634            },
635        ]
636    }
637}
638
639impl Default for QueryClassifier {
640    fn default() -> Self {
641        Self::new()
642    }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    /// A "What is ...?" query is classified as a definitional question.
    #[tokio::test]
    async fn test_definitional_query() {
        let classifier = QueryClassifier::new();

        let result = classifier
            .classify("What is machine learning?")
            .await
            .unwrap();
        assert_eq!(result.intent, QueryIntent::Definitional);
        assert_eq!(result.query_type, QueryType::Question);
        assert!(result.confidence > 0.5);
    }

    /// A "How to ..." query is classified as procedural.
    #[tokio::test]
    async fn test_procedural_query() {
        let classifier = QueryClassifier::new();

        let result = classifier
            .classify("How to implement a REST API?")
            .await
            .unwrap();
        assert_eq!(result.intent, QueryIntent::Procedural);
        assert!(result.confidence > 0.5);
    }

    /// An "X vs Y" query is classified as comparative.
    #[tokio::test]
    async fn test_comparative_query() {
        let classifier = QueryClassifier::new();

        let result = classifier
            .classify("Python vs Rust performance comparison")
            .await
            .unwrap();
        assert_eq!(result.intent, QueryIntent::Comparative);
        assert!(result.confidence > 0.5);
    }
}