atomr_agents_deep_research_shell/
classifier.rs1use async_trait::async_trait;
12use atomr_agents_deep_research_core::ResearchRequest;
13
14use crate::error::Result;
15
/// The amount of research effort a request is routed to.
///
/// Produced by an [`IntentClassifier`]; the mapping from tier to concrete
/// research pipeline lives outside this module.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ResearchTier {
    /// The lighter tier. [`HeuristicIntentClassifier`] returns this when none of
    /// its escalation checks fire.
    Shallow,
    /// The heavier tier. [`HeuristicIntentClassifier`] returns this as soon as
    /// any of its escalation checks fires.
    Deep,
}
24
25#[async_trait]
27pub trait IntentClassifier: Send + Sync + 'static {
28 async fn classify(&self, req: &ResearchRequest) -> Result<ResearchTier>;
30}
31
/// Rule-based [`IntentClassifier`] that needs no model call.
///
/// A request is escalated to [`ResearchTier::Deep`] when its query looks too
/// long, contains too many `?`, asks for too much depth, or contains one of the
/// configured comparative markers; otherwise it is [`ResearchTier::Shallow`].
/// See `classify_sync` for the exact comparisons.
#[derive(Clone, Debug)]
pub struct HeuristicIntentClassifier {
    /// Query-length threshold for the shallow tier, counted in `char`s
    /// (Unicode scalar values), not bytes.
    pub max_shallow_query_chars: usize,
    /// Threshold on the number of `?` characters in the query for the shallow tier.
    pub max_shallow_question_marks: usize,
    /// Threshold on `ResearchRequest::depth` for the shallow tier.
    pub max_shallow_depth: u32,
    /// Phrases that escalate a query to the deep tier when found in it
    /// (compared after lowercasing both the query and the phrase).
    pub comparative_markers: Vec<String>,
}
64
65impl Default for HeuristicIntentClassifier {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl HeuristicIntentClassifier {
72 pub fn new() -> Self {
74 Self {
75 max_shallow_query_chars: 80,
76 max_shallow_question_marks: 1,
77 max_shallow_depth: 1,
78 comparative_markers: default_comparative_markers(),
79 }
80 }
81
82 pub fn with_max_shallow_query_chars(mut self, n: usize) -> Self {
84 self.max_shallow_query_chars = n;
85 self
86 }
87
88 pub fn with_max_shallow_question_marks(mut self, n: usize) -> Self {
90 self.max_shallow_question_marks = n;
91 self
92 }
93
94 pub fn with_max_shallow_depth(mut self, n: u32) -> Self {
96 self.max_shallow_depth = n;
97 self
98 }
99
100 pub fn with_comparative_markers<I, S>(mut self, markers: I) -> Self
102 where
103 I: IntoIterator<Item = S>,
104 S: Into<String>,
105 {
106 self.comparative_markers = markers.into_iter().map(Into::into).collect();
107 self
108 }
109
110 pub fn classify_sync(&self, req: &ResearchRequest) -> ResearchTier {
113 let query = req.query.as_str();
114
115 let char_count = query.chars().count();
117 if char_count >= self.max_shallow_query_chars {
118 return ResearchTier::Deep;
119 }
120
121 let qm = query.chars().filter(|c| *c == '?').count();
123 if qm > self.max_shallow_question_marks {
124 return ResearchTier::Deep;
125 }
126
127 if req.depth > self.max_shallow_depth {
129 return ResearchTier::Deep;
130 }
131
132 let lowered = query.to_lowercase();
134 for marker in &self.comparative_markers {
135 if lowered.contains(&marker.to_lowercase()) {
136 return ResearchTier::Deep;
137 }
138 }
139
140 ResearchTier::Shallow
141 }
142}
143
144#[async_trait]
145impl IntentClassifier for HeuristicIntentClassifier {
146 async fn classify(&self, req: &ResearchRequest) -> Result<ResearchTier> {
147 Ok(self.classify_sync(req))
148 }
149}
150
/// Built-in phrases that escalate a query to the deep tier.
///
/// The caller lowercases both the query and each phrase before a substring
/// search. Leading and trailing spaces are part of a phrase (e.g. `" vs "`,
/// `"how do "`), so those phrases only match where the spaces are present.
fn default_comparative_markers() -> Vec<String> {
    const MARKERS: [&str; 12] = [
        "compare",
        "versus",
        " vs ",
        " vs.",
        "trade-off",
        "tradeoff",
        "analyze",
        "deep dive",
        "research",
        "contrast",
        "differences between",
        "how do ",
    ];
    MARKERS.iter().map(|marker| (*marker).to_owned()).collect()
}
175
#[cfg(test)]
mod tests {
    // Unit tests for the heuristic classifier's default thresholds and builders.
    use super::*;
    use atomr_agents_deep_research_core::ResearchRequest;

    #[test]
    fn defaults_route_short_queries_shallow() {
        let c = HeuristicIntentClassifier::new();
        let req = ResearchRequest::new("rust").with_depth(1);
        assert_eq!(c.classify_sync(&req), ResearchTier::Shallow);
    }

    #[test]
    fn comparative_markers_force_deep() {
        let c = HeuristicIntentClassifier::new();
        let req = ResearchRequest::new("compare actor frameworks");
        assert_eq!(c.classify_sync(&req), ResearchTier::Deep);
    }

    #[test]
    fn comparative_markers_match_case_insensitively() {
        let c = HeuristicIntentClassifier::new();
        let req = ResearchRequest::new("Rust VERSUS Go").with_depth(0);
        assert_eq!(c.classify_sync(&req), ResearchTier::Deep);
    }

    #[test]
    fn depth_above_threshold_forces_deep() {
        let c = HeuristicIntentClassifier::new();
        let req = ResearchRequest::new("rust").with_depth(3);
        assert_eq!(c.classify_sync(&req), ResearchTier::Deep);
    }

    #[test]
    fn long_queries_force_deep() {
        let c = HeuristicIntentClassifier::new();
        let long = "a".repeat(120);
        let req = ResearchRequest::new(long).with_depth(0);
        assert_eq!(c.classify_sync(&req), ResearchTier::Deep);
    }

    #[test]
    fn multiple_question_marks_force_deep() {
        let c = HeuristicIntentClassifier::new();
        let req = ResearchRequest::new("what? when? where?").with_depth(0);
        assert_eq!(c.classify_sync(&req), ResearchTier::Deep);
    }

    #[test]
    fn builders_override_thresholds() {
        let c = HeuristicIntentClassifier::new()
            .with_max_shallow_query_chars(10)
            .with_max_shallow_depth(0)
            .with_comparative_markers(Vec::<String>::new());
        let req = ResearchRequest::new("hello world").with_depth(0);
        assert_eq!(c.classify_sync(&req), ResearchTier::Deep);

        let short = ResearchRequest::new("hi").with_depth(0);
        assert_eq!(c.classify_sync(&short), ResearchTier::Shallow);
    }

    #[test]
    fn question_mark_builder_overrides_threshold() {
        // With a limit of zero, a single `?` is enough to escalate.
        let c = HeuristicIntentClassifier::new().with_max_shallow_question_marks(0);
        let asked = ResearchRequest::new("rust?").with_depth(0);
        assert_eq!(c.classify_sync(&asked), ResearchTier::Deep);

        let plain = ResearchRequest::new("rust").with_depth(0);
        assert_eq!(c.classify_sync(&plain), ResearchTier::Shallow);
    }
}