1use crate::RragResult;
7use serde::{Deserialize, Serialize};
8
/// Rule-based classifier that assigns an intent and a surface query type to a
/// search query and derives retrieval hints (entities, domain, complexity,
/// suggested strategies) from it.
///
/// The rule tables are fixed and are built by [`QueryClassifier::new`]; run a
/// query through [`QueryClassifier::classify`].
pub struct QueryClassifier {
    /// Intent rules, each matched against the lowercased query.
    patterns: Vec<IntentPattern>,

    /// Query-type rules (question, command, complex), matched against the
    /// lowercased query.
    type_patterns: Vec<TypePattern>,
}
17
18#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
20pub enum QueryIntent {
21 Factual,
23 Conceptual,
25 Procedural,
27 Comparative,
29 Troubleshooting,
31 Exploratory,
33 Definitional,
35 OpinionSeeking,
37}
38
39#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
41pub enum QueryType {
42 Question,
44 Command,
46 Keywords,
48 Statement,
50 Complex,
52}
53
/// Output of [`QueryClassifier::classify`] for a single query.
#[derive(Debug, Clone)]
pub struct ClassificationResult {
    /// The query exactly as it was passed to `classify`.
    pub query: String,

    /// Detected intent of the query.
    pub intent: QueryIntent,

    /// Detected surface form of the query.
    pub query_type: QueryType,

    /// Overall confidence, combined from the intent and query-type confidences.
    pub confidence: f32,

    /// Additional hints derived from the query.
    pub metadata: ClassificationMetadata,
}
72
/// Retrieval hints attached to a [`ClassificationResult`].
#[derive(Debug, Clone)]
pub struct ClassificationMetadata {
    /// Terms picked out of the query as likely entities (for example known
    /// technology names such as Rust or Docker).
    pub entities: Vec<String>,

    /// Coarse subject area such as `"technology"` or `"health"` when enough
    /// domain keywords appear in the query; `None` otherwise.
    pub domain: Option<String>,

    /// Heuristic complexity score in `[0.0, 1.0]`, built from query length,
    /// question words, conjunctions and repeated `?`.
    pub complexity: f32,

    /// Whether the query looks like it depends on earlier context (words such
    /// as "it" or "previous", or a very short query).
    pub needs_context: bool,

    /// Names of suggested retrieval strategies, e.g. `"keyword_search"`.
    pub suggested_strategies: Vec<String>,
}
91
/// One intent rule: words and phrases whose presence in a query is evidence
/// for `intent`.
struct IntentPattern {
    /// Intent reported when this rule scores best.
    intent: QueryIntent,
    /// Single words looked for among the query's tokens.
    keywords: Vec<String>,
    /// Phrases looked for in the query; a phrase hit is weighted more heavily
    /// than a keyword hit.
    patterns: Vec<String>,
    /// Base confidence of the rule, used to scale its score.
    confidence: f32,
}
103
/// One query-type rule: indicators whose presence suggests `query_type`.
struct TypePattern {
    /// Query type reported when this rule scores best.
    query_type: QueryType,
    /// Words (or punctuation such as `?`) that signal this type.
    indicators: Vec<String>,
    /// Base confidence of the rule, used to scale its score.
    confidence: f32,
}
113
114impl QueryClassifier {
115 pub fn new() -> Self {
117 let patterns = Self::init_intent_patterns();
118 let type_patterns = Self::init_type_patterns();
119
120 Self {
121 patterns,
122 type_patterns,
123 }
124 }
125
126 pub async fn classify(&self, query: &str) -> RragResult<ClassificationResult> {
128 let query_lower = query.to_lowercase();
129 let tokens = self.tokenize(&query_lower);
130
131 let (intent, intent_confidence) = self.detect_intent(&query_lower, &tokens);
133
134 let (query_type, type_confidence) = self.detect_query_type(&query_lower, &tokens);
136
137 let entities = self.extract_entities(&tokens);
139
140 let domain = self.detect_domain(&tokens);
142
143 let complexity = self.calculate_complexity(query, &tokens);
145
146 let needs_context = self.needs_context(query, &tokens);
148
149 let suggested_strategies = self.suggest_strategies(&intent, &query_type, complexity);
151
152 let confidence = intent_confidence.min(type_confidence);
154
155 Ok(ClassificationResult {
156 query: query.to_string(),
157 intent,
158 query_type,
159 confidence,
160 metadata: ClassificationMetadata {
161 entities,
162 domain,
163 complexity,
164 needs_context,
165 suggested_strategies,
166 },
167 })
168 }
169
170 fn detect_intent(&self, query: &str, tokens: &[String]) -> (QueryIntent, f32) {
172 let mut best_intent = QueryIntent::Factual;
173 let mut best_confidence = 0.0;
174
175 for pattern in &self.patterns {
176 let mut score = 0.0;
177 let mut matches = 0;
178
179 for keyword in &pattern.keywords {
181 if tokens.iter().any(|t| t.contains(keyword)) {
182 score += 1.0;
183 matches += 1;
184 }
185 }
186
187 for phrase in &pattern.patterns {
189 if query.contains(phrase) {
190 score += 2.0; matches += 1;
192 }
193 }
194
195 if matches > 0 {
196 let normalized_score = (score
198 / (pattern.keywords.len() + pattern.patterns.len()) as f32)
199 * pattern.confidence;
200
201 if normalized_score > best_confidence {
202 best_intent = pattern.intent.clone();
203 best_confidence = normalized_score;
204 }
205 }
206 }
207
208 if best_confidence < 0.3 {
210 if query.starts_with("what is") || query.starts_with("define") {
211 best_intent = QueryIntent::Definitional;
212 best_confidence = 0.6;
213 } else if query.starts_with("how to") || query.contains("step") {
214 best_intent = QueryIntent::Procedural;
215 best_confidence = 0.6;
216 } else if query.contains("compare")
217 || query.contains("vs")
218 || query.contains("difference")
219 {
220 best_intent = QueryIntent::Comparative;
221 best_confidence = 0.6;
222 }
223 }
224
225 (best_intent, best_confidence)
226 }
227
228 fn detect_query_type(&self, query: &str, tokens: &[String]) -> (QueryType, f32) {
230 let mut best_type = QueryType::Keywords;
231 let mut best_confidence = 0.0;
232
233 for pattern in &self.type_patterns {
234 let mut matches = 0;
235
236 for indicator in &pattern.indicators {
237 if query.contains(indicator) || tokens.iter().any(|t| t == indicator) {
238 matches += 1;
239 }
240 }
241
242 if matches > 0 {
243 let confidence =
244 (matches as f32 / pattern.indicators.len() as f32) * pattern.confidence;
245 if confidence > best_confidence {
246 best_type = pattern.query_type.clone();
247 best_confidence = confidence;
248 }
249 }
250 }
251
252 if best_confidence < 0.5 {
254 if query.ends_with('?') {
255 best_type = QueryType::Question;
256 best_confidence = 0.8;
257 } else if tokens.len() <= 3 {
258 best_type = QueryType::Keywords;
259 best_confidence = 0.7;
260 } else if tokens.len() > 10 {
261 best_type = QueryType::Complex;
262 best_confidence = 0.6;
263 } else {
264 best_type = QueryType::Statement;
265 best_confidence = 0.5;
266 }
267 }
268
269 (best_type, best_confidence)
270 }
271
272 fn extract_entities(&self, tokens: &[String]) -> Vec<String> {
274 let mut entities = Vec::new();
275
276 for token in tokens {
278 if token.chars().next().map_or(false, |c| c.is_uppercase()) {
280 entities.push(token.clone());
281 }
282
283 let tech_terms = [
285 "api",
286 "sql",
287 "json",
288 "html",
289 "css",
290 "javascript",
291 "python",
292 "rust",
293 "docker",
294 ];
295 if tech_terms.contains(&token.to_lowercase().as_str()) {
296 entities.push(token.clone());
297 }
298 }
299
300 entities
301 }
302
303 fn detect_domain(&self, tokens: &[String]) -> Option<String> {
305 let domains = [
306 (
307 "technology",
308 vec![
309 "code",
310 "programming",
311 "software",
312 "api",
313 "database",
314 "algorithm",
315 "computer",
316 ],
317 ),
318 (
319 "science",
320 vec![
321 "research",
322 "study",
323 "experiment",
324 "theory",
325 "analysis",
326 "data",
327 "scientific",
328 ],
329 ),
330 (
331 "business",
332 vec![
333 "market",
334 "sales",
335 "revenue",
336 "customer",
337 "profit",
338 "strategy",
339 "management",
340 ],
341 ),
342 (
343 "health",
344 vec![
345 "medical",
346 "health",
347 "disease",
348 "treatment",
349 "doctor",
350 "medicine",
351 "patient",
352 ],
353 ),
354 (
355 "education",
356 vec![
357 "learn",
358 "study",
359 "school",
360 "university",
361 "course",
362 "education",
363 "teach",
364 ],
365 ),
366 ];
367
368 for (domain, keywords) in &domains {
369 let matches = keywords
370 .iter()
371 .filter(|&&keyword| tokens.iter().any(|t| t.contains(keyword)))
372 .count();
373
374 if matches >= 2 || (matches == 1 && tokens.len() <= 5) {
375 return Some(domain.to_string());
376 }
377 }
378
379 None
380 }
381
382 fn calculate_complexity(&self, query: &str, tokens: &[String]) -> f32 {
384 let mut complexity = 0.0;
385
386 complexity += (tokens.len() as f32 / 20.0).min(1.0) * 0.3;
388
389 let question_words = ["what", "how", "why", "when", "where", "which", "who"];
391 let question_count = question_words
392 .iter()
393 .filter(|&&word| tokens.iter().any(|t| t == word))
394 .count();
395 complexity += (question_count as f32 * 0.1).min(0.3);
396
397 let conjunctions = ["and", "or", "but", "however", "also", "additionally"];
399 let conjunction_count = conjunctions
400 .iter()
401 .filter(|&&word| tokens.iter().any(|t| t == word))
402 .count();
403 complexity += (conjunction_count as f32 * 0.15).min(0.2);
404
405 if query.matches('?').count() > 1 {
407 complexity += 0.2;
408 }
409
410 complexity.min(1.0)
411 }
412
413 fn needs_context(&self, _query: &str, tokens: &[String]) -> bool {
415 let context_indicators = [
416 "this",
417 "that",
418 "it",
419 "they",
420 "them",
421 "previous",
422 "above",
423 "following",
424 ];
425 let pronouns = ["it", "this", "that", "these", "those"];
426
427 let has_pronouns = pronouns
429 .iter()
430 .any(|&pronoun| tokens.contains(&pronoun.to_string()));
431
432 let has_context_indicators = context_indicators
434 .iter()
435 .any(|&indicator| tokens.contains(&indicator.to_string()));
436
437 let is_very_short = tokens.len() <= 2;
439
440 has_pronouns || has_context_indicators || is_very_short
441 }
442
443 fn suggest_strategies(
445 &self,
446 intent: &QueryIntent,
447 query_type: &QueryType,
448 complexity: f32,
449 ) -> Vec<String> {
450 let mut strategies = Vec::new();
451
452 match intent {
453 QueryIntent::Factual => {
454 strategies.push("keyword_search".to_string());
455 strategies.push("exact_match".to_string());
456 }
457 QueryIntent::Conceptual => {
458 strategies.push("semantic_search".to_string());
459 strategies.push("related_documents".to_string());
460 }
461 QueryIntent::Procedural => {
462 strategies.push("step_by_step".to_string());
463 strategies.push("tutorial_search".to_string());
464 }
465 QueryIntent::Comparative => {
466 strategies.push("comparative_analysis".to_string());
467 strategies.push("side_by_side".to_string());
468 }
469 QueryIntent::Troubleshooting => {
470 strategies.push("problem_solution".to_string());
471 strategies.push("diagnostic".to_string());
472 }
473 QueryIntent::Exploratory => {
474 strategies.push("broad_search".to_string());
475 strategies.push("topic_exploration".to_string());
476 }
477 QueryIntent::Definitional => {
478 strategies.push("definition_search".to_string());
479 strategies.push("glossary_lookup".to_string());
480 }
481 QueryIntent::OpinionSeeking => {
482 strategies.push("review_search".to_string());
483 strategies.push("opinion_mining".to_string());
484 }
485 }
486
487 match query_type {
488 QueryType::Complex => {
489 strategies.push("query_decomposition".to_string());
490 strategies.push("multi_step_search".to_string());
491 }
492 QueryType::Keywords => {
493 strategies.push("keyword_expansion".to_string());
494 strategies.push("term_matching".to_string());
495 }
496 _ => {}
497 }
498
499 if complexity > 0.7 {
500 strategies.push("complex_reasoning".to_string());
501 strategies.push("multi_document_synthesis".to_string());
502 }
503
504 strategies
505 }
506
507 fn tokenize(&self, query: &str) -> Vec<String> {
509 query
510 .split_whitespace()
511 .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
512 .filter(|s| !s.is_empty())
513 .map(|s| s.to_lowercase())
514 .collect()
515 }
516
517 fn init_intent_patterns() -> Vec<IntentPattern> {
519 vec![
520 IntentPattern {
521 intent: QueryIntent::Definitional,
522 keywords: vec![
523 "define".to_string(),
524 "definition".to_string(),
525 "meaning".to_string(),
526 ],
527 patterns: vec![
528 "what is".to_string(),
529 "what does".to_string(),
530 "define".to_string(),
531 ],
532 confidence: 0.9,
533 },
534 IntentPattern {
535 intent: QueryIntent::Procedural,
536 keywords: vec![
537 "how".to_string(),
538 "step".to_string(),
539 "tutorial".to_string(),
540 "guide".to_string(),
541 ],
542 patterns: vec![
543 "how to".to_string(),
544 "step by step".to_string(),
545 "how do i".to_string(),
546 ],
547 confidence: 0.9,
548 },
549 IntentPattern {
550 intent: QueryIntent::Comparative,
551 keywords: vec![
552 "compare".to_string(),
553 "difference".to_string(),
554 "better".to_string(),
555 "versus".to_string(),
556 ],
557 patterns: vec![
558 "vs".to_string(),
559 "compared to".to_string(),
560 "difference between".to_string(),
561 ],
562 confidence: 0.8,
563 },
564 IntentPattern {
565 intent: QueryIntent::Troubleshooting,
566 keywords: vec![
567 "problem".to_string(),
568 "error".to_string(),
569 "fix".to_string(),
570 "issue".to_string(),
571 "broken".to_string(),
572 ],
573 patterns: vec![
574 "not working".to_string(),
575 "how to fix".to_string(),
576 "troubleshoot".to_string(),
577 ],
578 confidence: 0.8,
579 },
580 IntentPattern {
581 intent: QueryIntent::Factual,
582 keywords: vec![
583 "when".to_string(),
584 "where".to_string(),
585 "who".to_string(),
586 "which".to_string(),
587 ],
588 patterns: vec![
589 "when did".to_string(),
590 "where is".to_string(),
591 "who created".to_string(),
592 ],
593 confidence: 0.7,
594 },
595 ]
596 }
597
598 fn init_type_patterns() -> Vec<TypePattern> {
600 vec![
601 TypePattern {
602 query_type: QueryType::Question,
603 indicators: vec![
604 "?".to_string(),
605 "what".to_string(),
606 "how".to_string(),
607 "why".to_string(),
608 "when".to_string(),
609 "where".to_string(),
610 ],
611 confidence: 0.9,
612 },
613 TypePattern {
614 query_type: QueryType::Command,
615 indicators: vec![
616 "show".to_string(),
617 "find".to_string(),
618 "get".to_string(),
619 "list".to_string(),
620 "give".to_string(),
621 ],
622 confidence: 0.8,
623 },
624 TypePattern {
625 query_type: QueryType::Complex,
626 indicators: vec![
627 "and".to_string(),
628 "or".to_string(),
629 "but".to_string(),
630 "however".to_string(),
631 "also".to_string(),
632 ],
633 confidence: 0.7,
634 },
635 ]
636 }
637}
638
639impl Default for QueryClassifier {
640 fn default() -> Self {
641 Self::new()
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies `query` with a freshly built classifier.
    async fn classify(query: &str) -> ClassificationResult {
        QueryClassifier::new().classify(query).await.unwrap()
    }

    #[tokio::test]
    async fn test_definitional_query() {
        let outcome = classify("What is machine learning?").await;

        assert_eq!(outcome.intent, QueryIntent::Definitional);
        assert_eq!(outcome.query_type, QueryType::Question);
        assert!(outcome.confidence > 0.5);
    }

    #[tokio::test]
    async fn test_procedural_query() {
        let outcome = classify("How to implement a REST API?").await;

        assert_eq!(outcome.intent, QueryIntent::Procedural);
        assert!(outcome.confidence > 0.5);
    }

    #[tokio::test]
    async fn test_comparative_query() {
        let outcome = classify("Python vs Rust performance comparison").await;

        assert_eq!(outcome.intent, QueryIntent::Comparative);
        assert!(outcome.confidence > 0.5);
    }
}
666 let result = classifier
667 .classify("How to implement a REST API?")
668 .await
669 .unwrap();
670 assert_eq!(result.intent, QueryIntent::Procedural);
671 assert!(result.confidence > 0.5);
672 }
673
674 #[tokio::test]
675 async fn test_comparative_query() {
676 let classifier = QueryClassifier::new();
677
678 let result = classifier
679 .classify("Python vs Rust performance comparison")
680 .await
681 .unwrap();
682 assert_eq!(result.intent, QueryIntent::Comparative);
683 assert!(result.confidence > 0.5);
684 }
685}