// rs_agent/query.rs

//! Query classification utilities
//!
//! This module provides query type detection to optimize context retrieval,
//! matching the structure from go-agent's query.go.
/// Category of a user query, used to pick how much context to retrieve.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryType {
    /// Mathematical or computational request that needs no retrieved context.
    Math,
    /// Brief fact-seeking question that needs only minimal context.
    ShortFactoid,
    /// Explanation or multi-step request that needs full context retrieval.
    Complex,
    /// Query that matched no known category; retrieval is skipped.
    Unknown,
}

19/// Classifies a query to determine optimal retrieval strategy
20pub fn classify_query(query: &str) -> QueryType {
21    let lower = query.trim().to_lowercase();
22    
23    // Math queries - numerical operations, calculations
24    if is_math_query(&lower) {
25        return QueryType::Math;
26    }
27    
28    // Short factoid queries - simple definitions, factual questions
29    if is_short_factoid(&lower) {
30        return QueryType::ShortFactoid;
31    }
32    
33    // Complex queries - explanations, multi-step reasoning
34    if is_complex_query(&lower) {
35        return QueryType::Complex;
36    }
37    
38    QueryType::Unknown
39}
40
/// Returns `true` when `query` looks like a mathematical or computational
/// request. The caller is expected to pass trimmed, lowercased text, because
/// the keywords below are lowercase.
///
/// Two independent signals are accepted:
///
/// * **Arithmetic notation**: one of `+ - * / ^ =` whose nearest
///   non-whitespace neighbour on at least one side is an ASCII digit, as in
///   `5 + 3`, `x^2 = 4` or `10/2`. Requiring a digit stops hyphenated words
///   ("well-known") and slash-separated names ("tcp/ip", "async/await") from
///   being mistaken for maths.
/// * **Math vocabulary**: a keyword such as "calculate" or "sum" appearing as
///   a whole word, optionally with a trailing plural `s` ("equations").
///   Whole-word matching stops "add" firing on "address", "sum" on "summary"
///   and "compute" on "computer". Other inflections ("added", "computed") are
///   not recognised.
fn is_math_query(query: &str) -> bool {
    const OPERATORS: [char; 6] = ['+', '-', '*', '/', '^', '='];
    const KEYWORDS: [&str; 11] = [
        "calculate",
        "compute",
        "solve",
        "equation",
        "sum",
        "multiply",
        "divide",
        "subtract",
        "add",
        "integral",
        "derivative",
    ];

    /// Whether the first non-whitespace character yielded by `neighbours` is a digit.
    fn first_non_space_is_digit<'a>(mut neighbours: impl Iterator<Item = &'a char>) -> bool {
        matches!(neighbours.find(|c| !c.is_whitespace()), Some(c) if c.is_ascii_digit())
    }

    let chars: Vec<char> = query.chars().collect();
    // Look left (reversed prefix) and right (suffix) of each operator for a digit operand.
    let has_arithmetic = chars.iter().enumerate().any(|(i, c)| {
        OPERATORS.contains(c)
            && (first_non_space_is_digit(chars[..i].iter().rev())
                || first_non_space_is_digit(chars[i + 1..].iter()))
    });

    // Split on anything non-alphanumeric so punctuation ("sum?") doesn't hide a keyword.
    let has_keyword = query.split(|c: char| !c.is_alphanumeric()).any(|word| {
        let singular = word.strip_suffix('s').unwrap_or(word);
        KEYWORDS.contains(&word) || KEYWORDS.contains(&singular)
    });

    has_arithmetic || has_keyword
}

/// Returns `true` for a brief, fact-seeking question.
///
/// The query must begin with one of a fixed set of openers (`"what is"`,
/// `"who is"`, `"when was"`, `"where is"`, `"which"`, `"define"`) and contain
/// fewer than 15 whitespace-separated words. Openers are compared literally and
/// case-sensitively, so callers should pass lowercased text.
fn is_short_factoid(query: &str) -> bool {
    const FACTOID_OPENERS: [&str; 6] = [
        "what is",
        "who is",
        "when was",
        "where is",
        "which",
        "define",
    ];
    const WORD_LIMIT: usize = 15;

    let opens_like_factoid = FACTOID_OPENERS
        .iter()
        .any(|&opener| query.starts_with(opener));
    if !opens_like_factoid {
        return false;
    }

    // Must stay relatively short to count as a factoid.
    query.split_whitespace().count() < WORD_LIMIT
}

/// Returns `true` when the query asks for an explanation or is long enough to
/// suggest multi-step reasoning.
///
/// A query qualifies if any of these hold:
/// * it contains a complexity marker such as "explain" or "walk me through"
///   (plain substring match; the markers are lowercase);
/// * it has more than 20 whitespace-separated words;
/// * it has more than 10 words and contains a `?`.
fn is_complex_query(query: &str) -> bool {
    const COMPLEXITY_MARKERS: [&str; 10] = [
        "explain",
        "describe",
        "analyze",
        "compare",
        "discuss",
        "evaluate",
        "how does",
        "why does",
        "tell me about",
        "walk me through",
    ];

    if COMPLEXITY_MARKERS
        .iter()
        .any(|&marker| query.contains(marker))
    {
        return true;
    }

    // With no explicit marker, length decides: very long queries always
    // qualify, moderately long ones only when phrased as a question.
    match query.split_whitespace().count() {
        words if words > 20 => true,
        words if words > 10 => query.contains('?'),
        _ => false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every query in `queries` is classified as `expected`,
    /// naming the offending query if one is not.
    fn assert_classified_as(expected: QueryType, queries: &[&str]) {
        for &query in queries {
            assert_eq!(classify_query(query), expected, "query: {:?}", query);
        }
    }

    #[test]
    fn test_classify_math_queries() {
        assert_classified_as(
            QueryType::Math,
            &["What is 5 + 3?", "Calculate the sum of 10 and 20", "Solve x^2 = 4"],
        );
    }

    #[test]
    fn test_classify_short_factoid() {
        assert_classified_as(
            QueryType::ShortFactoid,
            &["What is Rust?", "Who is the president?", "When was Python created?"],
        );
    }

    #[test]
    fn test_classify_complex_queries() {
        assert_classified_as(
            QueryType::Complex,
            &[
                "Explain how async/await works in Rust",
                "Tell me about the history of programming languages and their evolution over time",
                "Why does the borrow checker prevent certain patterns?",
            ],
        );
    }

    #[test]
    fn test_classify_unknown() {
        assert_classified_as(QueryType::Unknown, &["Hello", ""]);
    }

    #[test]
    fn test_edge_cases() {
        assert_classified_as(
            QueryType::Complex,
            &[
                // Opens with "What is" but is too long for a short factoid
                "What is the meaning of life and how do we determine our purpose in this vast universe?",
                // Math vocabulary inside an explanation request
                "Explain how to solve quadratic equations",
            ],
        );
    }
}