/// Coarse category of a user query, as returned by [`classify_query`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryType {
    /// A math problem: arithmetic operators or math vocabulary.
    Math,
    /// A short, direct question such as "what is …" or "who is …".
    ShortFactoid,
    /// An open-ended request that calls for an explanatory answer.
    Complex,
    /// A query that fits none of the other categories.
    Unknown,
}
18
19pub fn classify_query(query: &str) -> QueryType {
21 let lower = query.trim().to_lowercase();
22
23 if is_math_query(&lower) {
25 return QueryType::Math;
26 }
27
28 if is_short_factoid(&lower) {
30 return QueryType::ShortFactoid;
31 }
32
33 if is_complex_query(&lower) {
35 return QueryType::Complex;
36 }
37
38 QueryType::Unknown
39}
40
/// Returns `true` when `query` (expected to be lower-cased) looks like a math problem.
///
/// Two signals are accepted:
///
/// * an arithmetic expression: one of `+ - * / ^ =` with a numeric or
///   single-letter operand on both sides (`5 + 3`, `x^2 = 4`, `(2+3)*4`,
///   `x = -3`). Hyphens and slashes in prose (`well-known`, `async/await`,
///   `e-mail`) therefore do not count.
/// * a math keyword as a whole word, optionally followed by a plural or
///   third-person `s` (`solve`, `equations`). Words that merely contain a
///   keyword (`address`, `summary`, `resolve`) therefore do not count.
fn is_math_query(query: &str) -> bool {
    const MATH_KEYWORDS: [&str; 11] = [
        "calculate",
        "compute",
        "solve",
        "equation",
        "sum",
        "multiply",
        "divide",
        "subtract",
        "add",
        "integral",
        "derivative",
    ];

    let has_keyword = query.split(|c: char| !c.is_alphanumeric()).any(|word| {
        let stem = word.strip_suffix('s').unwrap_or(word);
        MATH_KEYWORDS.iter().any(|&kw| kw == word || kw == stem)
    });

    has_keyword || has_arithmetic_expression(query)
}

/// Returns `true` if `query` contains an operator (`+ - * / ^ =`) whose nearest
/// operands on both sides look numeric or like single-letter variables.
fn has_arithmetic_expression(query: &str) -> bool {
    const OPERATORS: [char; 6] = ['+', '-', '*', '/', '^', '='];

    /// Characters that may appear inside an operand such as `3.14`, `2x` or `(x`.
    fn is_operand_char(c: char) -> bool {
        c.is_alphanumeric() || matches!(c, '.' | '(' | ')')
    }

    let chars: Vec<char> = query.chars().collect();

    chars.iter().enumerate().any(|(i, op)| {
        if !OPERATORS.contains(op) {
            return false;
        }

        // Left operand: skip whitespace, then take the run of operand characters.
        let before = &chars[..i];
        let left_end = before
            .iter()
            .rposition(|c| !c.is_whitespace())
            .map_or(0, |p| p + 1);
        let left_start = before[..left_end]
            .iter()
            .rposition(|&c| !is_operand_char(c))
            .map_or(0, |p| p + 1);

        // Right operand: skip whitespace and an optional unary sign (`x = -3`),
        // then take the run of operand characters.
        let after = &chars[i + 1..];
        let mut right_start = after
            .iter()
            .position(|c| !c.is_whitespace())
            .unwrap_or(after.len());
        if matches!(after.get(right_start), Some('+') | Some('-')) {
            right_start += 1;
        }
        let right_end = after[right_start..]
            .iter()
            .position(|&c| !is_operand_char(c))
            .map_or(after.len(), |p| right_start + p);

        looks_like_operand(&before[left_start..left_end])
            && looks_like_operand(&after[right_start..right_end])
    })
}

/// Returns `true` for a token that can act as an arithmetic operand. That is
/// one containing a digit (`3`, `2x`, `3.14`), a single letter optionally
/// wrapped in parentheses (`x`, `(y`), or bare parentheses delimiting a
/// sub-expression.
fn looks_like_operand(token: &[char]) -> bool {
    if token.iter().any(char::is_ascii_digit) {
        return true;
    }
    let inner: Vec<char> = token
        .iter()
        .copied()
        .filter(|&c| c != '(' && c != ')')
        .collect();
    match inner.as_slice() {
        [] => !token.is_empty(),
        [only] => only.is_alphabetic(),
        _ => false,
    }
}
67
/// Returns `true` for short, direct questions ("what is …", "who is …", …).
///
/// `query` is expected to be trimmed and lower-cased. A query qualifies when it
/// starts with one of the known question openers and has fewer than 15
/// whitespace-separated words.
fn is_short_factoid(query: &str) -> bool {
    const MAX_WORDS_EXCLUSIVE: usize = 15;
    const OPENERS: [&str; 6] = [
        "what is",
        "who is",
        "when was",
        "where is",
        "which",
        "define",
    ];

    let opens_like_factoid = OPENERS
        .iter()
        .copied()
        .any(|opener| query.starts_with(opener));

    opens_like_factoid && query.split_whitespace().count() < MAX_WORDS_EXCLUSIVE
}
85
/// Returns `true` for open-ended queries that call for a longer answer.
///
/// `query` is expected to be trimmed and lower-cased. Any one of these makes a
/// query complex:
///
/// * it contains an explanatory phrase such as "explain" or "walk me through"
///   (plain substring match);
/// * it has more than 20 whitespace-separated words;
/// * it contains a `?` and has more than 10 words.
fn is_complex_query(query: &str) -> bool {
    const EXPLANATORY_PHRASES: [&str; 10] = [
        "explain",
        "describe",
        "analyze",
        "compare",
        "discuss",
        "evaluate",
        "how does",
        "why does",
        "tell me about",
        "walk me through",
    ];

    let words = query.split_whitespace().count();
    let is_long = words > 20;
    let is_long_question = words > 10 && query.contains('?');
    let has_explanatory_phrase = EXPLANATORY_PHRASES
        .iter()
        .copied()
        .any(|phrase| query.contains(phrase));

    is_long || is_long_question || has_explanatory_phrase
}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies each query and checks it against the expected category,
    /// naming the offending query if one does not match.
    fn assert_classified(cases: &[(&str, QueryType)]) {
        for &(query, expected) in cases {
            assert_eq!(classify_query(query), expected, "query: {:?}", query);
        }
    }

    #[test]
    fn test_classify_math_queries() {
        assert_classified(&[
            ("What is 5 + 3?", QueryType::Math),
            ("Calculate the sum of 10 and 20", QueryType::Math),
            ("Solve x^2 = 4", QueryType::Math),
        ]);
    }

    #[test]
    fn test_classify_short_factoid() {
        assert_classified(&[
            ("What is Rust?", QueryType::ShortFactoid),
            ("Who is the president?", QueryType::ShortFactoid),
            ("When was Python created?", QueryType::ShortFactoid),
        ]);
    }

    #[test]
    fn test_classify_complex_queries() {
        assert_classified(&[
            ("Explain how async/await works in Rust", QueryType::Complex),
            (
                "Tell me about the history of programming languages and their evolution over time",
                QueryType::Complex,
            ),
            (
                "Why does the borrow checker prevent certain patterns?",
                QueryType::Complex,
            ),
        ]);
    }

    #[test]
    fn test_classify_unknown() {
        assert_classified(&[("Hello", QueryType::Unknown), ("", QueryType::Unknown)]);
    }

    #[test]
    fn test_edge_cases() {
        assert_classified(&[
            // Starts with "what is" but has too many words to be a short factoid.
            (
                "What is the meaning of life and how do we determine our purpose in this vast universe?",
                QueryType::Complex,
            ),
            // An explanation request about a math topic is expected to be Complex, not Math.
            ("Explain how to solve quadratic equations", QueryType::Complex),
        ]);
    }
}