1use hirn_query::ast::RetrievalMode;
24
/// Coarse complexity bucket that `classify_query` assigns to a query.
///
/// `classify_and_route` turns each bucket into a retrieval mode.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryComplexity {
    /// Heuristic score below 3; routed to `RetrievalMode::Local`.
    Simple,
    /// Heuristic score of 3 to 5; routed to `RetrievalMode::Hybrid`.
    Moderate,
    /// Heuristic score of 6 or more; routed to `RetrievalMode::Raptor`.
    Complex,
}
35
36pub fn classify_and_route(
42 query: &str,
43 involving_count: usize,
44 where_count: usize,
45 has_temporal: bool,
46 has_expand: bool,
47 has_follow_causes: bool,
48) -> RetrievalMode {
49 let complexity = classify_query(
50 query,
51 involving_count,
52 where_count,
53 has_temporal,
54 has_expand,
55 has_follow_causes,
56 );
57
58 match complexity {
59 QueryComplexity::Simple => RetrievalMode::Local,
60 QueryComplexity::Moderate => RetrievalMode::Hybrid,
61 QueryComplexity::Complex => RetrievalMode::Raptor,
62 }
63}
64
65pub fn classify_query(
67 query: &str,
68 involving_count: usize,
69 where_count: usize,
70 has_temporal: bool,
71 has_expand: bool,
72 has_follow_causes: bool,
73) -> QueryComplexity {
74 let mut score: u32 = 0;
75
76 let token_count = query.split_whitespace().count();
78 if token_count >= 20 {
79 score += 3;
80 } else if token_count >= 10 {
81 score += 2;
82 } else if token_count >= 4 {
83 score += 1;
84 }
85
86 score += (where_count as u32).min(3);
88 if involving_count > 2 {
89 score += 2;
90 } else if involving_count > 0 {
91 score += 1;
92 }
93
94 let lower = query.to_lowercase();
96 let complex_patterns = [
97 "compare",
98 "contrast",
99 "why",
100 "how does",
101 "what caused",
102 "relationship between",
103 "difference between",
104 "trade-off",
105 "pros and cons",
106 "implications of",
107 "summarize all",
108 "overview of",
109 "explain the",
110 "analyze",
111 ];
112 let moderate_patterns = [
113 "how", "what are", "describe", "list", "when did", "where", "who", "which",
114 ];
115
116 let complex_hits = complex_patterns
117 .iter()
118 .filter(|p| lower.contains(*p))
119 .count();
120 let moderate_hits = moderate_patterns
121 .iter()
122 .filter(|p| lower.contains(*p))
123 .count();
124
125 score += (complex_hits as u32) * 2;
126 score += (moderate_hits as u32).min(2);
127
128 if has_temporal {
130 score += 2;
131 }
132
133 if has_expand {
135 score += 3;
136 }
137 if has_follow_causes {
138 score += 3;
139 }
140
141 if score >= 6 {
143 QueryComplexity::Complex
144 } else if score >= 3 {
145 QueryComplexity::Moderate
146 } else {
147 QueryComplexity::Simple
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies `query` with no structural hints and every flag off.
    fn classify_plain(query: &str) -> QueryComplexity {
        classify_query(query, 0, 0, false, false, false)
    }

    #[test]
    fn simple_factoid_query() {
        // Three tokens and no cue phrase: nothing contributes to the score.
        assert_eq!(classify_plain("what is JWT"), QueryComplexity::Simple);
    }

    #[test]
    fn moderate_query_with_entity() {
        let query = "how does authentication work with OAuth tokens";
        assert_eq!(
            classify_query(query, 1, 0, false, false, false),
            QueryComplexity::Moderate
        );
    }

    #[test]
    fn complex_analytical_query() {
        let query =
            "compare the trade-off between JWT and session-based authentication across all services";
        assert_eq!(
            classify_query(query, 3, 1, false, true, false),
            QueryComplexity::Complex
        );
    }

    #[test]
    fn temporal_adds_complexity() {
        // Four tokens (+1) plus the temporal flag (+2) lands exactly on the
        // Moderate threshold of 3.
        assert_eq!(
            classify_query("what happened with deployments", 0, 0, true, false, false),
            QueryComplexity::Moderate
        );
    }

    #[test]
    fn follow_causes_is_complex() {
        // Five tokens (+1), "why" (+2) and the follow-causes flag (+3) reach the
        // Complex threshold of 6.
        assert_eq!(
            classify_query("why did the service fail", 0, 0, false, false, true),
            QueryComplexity::Complex
        );
    }

    #[test]
    fn classify_and_route_simple() {
        assert_eq!(
            classify_and_route("hello", 0, 0, false, false, false),
            RetrievalMode::Local
        );
    }

    #[test]
    fn classify_and_route_complex() {
        let query = "compare all authentication strategies and their trade-offs";
        assert_eq!(
            classify_and_route(query, 2, 1, true, true, false),
            RetrievalMode::Raptor
        );
    }
}