Skip to main content

atomr_agents_deep_research_shell/
classifier.rs

1//! Intent classifier: route an incoming [`ResearchRequest`] to either
2//! the fast shallow path or the full deep-research harness.
3//!
4//! The default deterministic implementation is
5//! [`HeuristicIntentClassifier`]; an LLM-backed
6//! `AgentBasedIntentClassifier` would slot in here once the `agent`
7//! feature on the deep-research-harness lands (see PR 4 of the v2
8//! plan). The trait is intentionally object-safe so callers can swap
9//! implementations behind an `Arc<dyn IntentClassifier>`.
10
11use async_trait::async_trait;
12use atomr_agents_deep_research_core::ResearchRequest;
13
14use crate::error::Result;
15
/// Routing decision for a [`ResearchRequest`]: the pipeline tier that
/// should service it.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ResearchTier {
    /// One web-search round; the clarifier, planner and critic loop are skipped.
    Shallow,
    /// The complete deep-research harness pipeline.
    Deep,
}
24
25/// Object-safe trait every intent classifier implements.
26#[async_trait]
27pub trait IntentClassifier: Send + Sync + 'static {
28    /// Decide which tier should handle `req`.
29    async fn classify(&self, req: &ResearchRequest) -> Result<ResearchTier>;
30}
31
32// NOTE: an `AgentBasedIntentClassifier` (LLM-driven, structured-output)
33// would live alongside this module guarded by the deep-research-harness
34// `agent` feature. PR 4 of the v2 plan introduces that feature and the
35// `InferenceClientFactory` plumbing it needs — until then the only
36// in-tree impl is the heuristic one below.
37
/// Deterministic, LLM-free intent classifier.
///
/// A request is classified as [`ResearchTier::Shallow`] only when **all**
/// of the following rules hold. If any single rule fails, it is routed to
/// [`ResearchTier::Deep`]:
///
/// 1. `req.query.chars().count()` (Unicode scalar values, not bytes) is
///    strictly less than
///    [`HeuristicIntentClassifier::max_shallow_query_chars`].
/// 2. The number of `?` characters in the query is at most
///    [`HeuristicIntentClassifier::max_shallow_question_marks`]. Several
///    sub-questions suggest a deep request.
/// 3. `req.depth` is at most [`HeuristicIntentClassifier::max_shallow_depth`].
/// 4. The query contains none of the
///    [`HeuristicIntentClassifier::comparative_markers`]
///    (case-insensitive substring match).
///
/// `HeuristicIntentClassifier::new()` (and therefore `Default`) uses these
/// settings:
///
/// - `max_shallow_query_chars = 80`
/// - `max_shallow_question_marks = 1`
/// - `max_shallow_depth = 1`
/// - the built-in comparative-marker list
///
/// Every threshold can be overridden with the `with_*` builder methods or
/// by assigning the public fields directly.
#[derive(Debug, Clone)]
pub struct HeuristicIntentClassifier {
    /// Strict upper bound on `chars().count()` of a shallow query.
    pub max_shallow_query_chars: usize,
    /// Maximum allowed `?` count in a shallow query.
    pub max_shallow_question_marks: usize,
    /// Maximum allowed `req.depth` for a shallow query.
    pub max_shallow_depth: u32,
    /// Case-insensitive substrings whose presence forces the request
    /// to the deep tier.
    pub comparative_markers: Vec<String>,
}
64
65impl Default for HeuristicIntentClassifier {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl HeuristicIntentClassifier {
72    /// Build a heuristic classifier with the documented defaults.
73    pub fn new() -> Self {
74        Self {
75            max_shallow_query_chars: 80,
76            max_shallow_question_marks: 1,
77            max_shallow_depth: 1,
78            comparative_markers: default_comparative_markers(),
79        }
80    }
81
82    /// Override the strict character-count upper bound.
83    pub fn with_max_shallow_query_chars(mut self, n: usize) -> Self {
84        self.max_shallow_query_chars = n;
85        self
86    }
87
88    /// Override the maximum allowed `?` count for shallow queries.
89    pub fn with_max_shallow_question_marks(mut self, n: usize) -> Self {
90        self.max_shallow_question_marks = n;
91        self
92    }
93
94    /// Override the maximum allowed `req.depth` for shallow queries.
95    pub fn with_max_shallow_depth(mut self, n: u32) -> Self {
96        self.max_shallow_depth = n;
97        self
98    }
99
100    /// Replace the comparative-marker list outright.
101    pub fn with_comparative_markers<I, S>(mut self, markers: I) -> Self
102    where
103        I: IntoIterator<Item = S>,
104        S: Into<String>,
105    {
106        self.comparative_markers = markers.into_iter().map(Into::into).collect();
107        self
108    }
109
110    /// Pure synchronous classifier used by the `async` trait method.
111    /// Public so callers and tests can exercise it without `await`.
112    pub fn classify_sync(&self, req: &ResearchRequest) -> ResearchTier {
113        let query = req.query.as_str();
114
115        // Rule 1: query length.
116        let char_count = query.chars().count();
117        if char_count >= self.max_shallow_query_chars {
118            return ResearchTier::Deep;
119        }
120
121        // Rule 2: '?' count.
122        let qm = query.chars().filter(|c| *c == '?').count();
123        if qm > self.max_shallow_question_marks {
124            return ResearchTier::Deep;
125        }
126
127        // Rule 3: depth.
128        if req.depth > self.max_shallow_depth {
129            return ResearchTier::Deep;
130        }
131
132        // Rule 4: comparative markers (case-insensitive).
133        let lowered = query.to_lowercase();
134        for marker in &self.comparative_markers {
135            if lowered.contains(&marker.to_lowercase()) {
136                return ResearchTier::Deep;
137            }
138        }
139
140        ResearchTier::Shallow
141    }
142}
143
144#[async_trait]
145impl IntentClassifier for HeuristicIntentClassifier {
146    async fn classify(&self, req: &ResearchRequest) -> Result<ResearchTier> {
147        Ok(self.classify_sync(req))
148    }
149}
150
/// Canonical default comparative-marker list used by
/// `HeuristicIntentClassifier::new()`.
///
/// Each marker is matched as a case-insensitive substring of the query. The
/// markers are plain text, not patterns or whole-word matches.
///
/// Some markers carry surrounding whitespace on purpose, to avoid false
/// positives inside other words. For example, `" vs "` matches
/// `"tokio vs async-std"` but not `"devs tooling"`, which a bare `"vs"`
/// would match. As a consequence, `" vs "` does not match a `vs` at the
/// very start or end of a query.
fn default_comparative_markers() -> Vec<String> {
    const MARKERS: &[&str] = &[
        "compare",
        "versus",
        " vs ",
        // Abbreviated form, e.g. "tokio vs. async-std".
        " vs.",
        "trade-off",
        "tradeoff",
        "analyze",
        "deep dive",
        "research",
        "contrast",
        "differences between",
        // The multi-entity comparative shape — narrower than a bare
        // "how", which would be far too aggressive.
        "how do ",
    ];
    MARKERS.iter().map(|&m| m.to_owned()).collect()
}
175
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_deep_research_core::ResearchRequest;

    #[test]
    fn defaults_route_short_queries_shallow() {
        let c = HeuristicIntentClassifier::new();
        let req = ResearchRequest::new("rust").with_depth(1);
        assert_eq!(c.classify_sync(&req), ResearchTier::Shallow);
    }

    #[test]
    fn comparative_markers_force_deep() {
        let c = HeuristicIntentClassifier::new();
        // Pin depth to 0 so only the marker rule can route this to Deep,
        // regardless of `ResearchRequest::new`'s default depth.
        let req = ResearchRequest::new("compare actor frameworks").with_depth(0);
        assert_eq!(c.classify_sync(&req), ResearchTier::Deep);
    }

    #[test]
    fn comparative_markers_match_case_insensitively() {
        let c = HeuristicIntentClassifier::new();
        let req = ResearchRequest::new("Tokio VERSUS async-std").with_depth(0);
        assert_eq!(c.classify_sync(&req), ResearchTier::Deep);
    }

    #[test]
    fn depth_above_threshold_forces_deep() {
        let c = HeuristicIntentClassifier::new();
        let req = ResearchRequest::new("rust").with_depth(3);
        assert_eq!(c.classify_sync(&req), ResearchTier::Deep);
    }

    #[test]
    fn long_queries_force_deep() {
        let c = HeuristicIntentClassifier::new();
        let long = "a".repeat(120);
        let req = ResearchRequest::new(long).with_depth(0);
        assert_eq!(c.classify_sync(&req), ResearchTier::Deep);
    }

    #[test]
    fn length_bound_is_strict_and_counts_chars_not_bytes() {
        let c = HeuristicIntentClassifier::new();
        // 79 chars < 80 → shallow; exactly 80 chars → deep.
        let just_under = ResearchRequest::new("a".repeat(79)).with_depth(0);
        assert_eq!(c.classify_sync(&just_under), ResearchTier::Shallow);
        let at_limit = ResearchRequest::new("a".repeat(80)).with_depth(0);
        assert_eq!(c.classify_sync(&at_limit), ResearchTier::Deep);
        // 79 two-byte chars (158 bytes) still count as 79 chars.
        let multibyte = ResearchRequest::new("é".repeat(79)).with_depth(0);
        assert_eq!(c.classify_sync(&multibyte), ResearchTier::Shallow);
    }

    #[test]
    fn single_question_mark_stays_shallow() {
        let c = HeuristicIntentClassifier::new();
        let req = ResearchRequest::new("what is rust?").with_depth(0);
        assert_eq!(c.classify_sync(&req), ResearchTier::Shallow);
    }

    #[test]
    fn multiple_question_marks_force_deep() {
        let c = HeuristicIntentClassifier::new();
        let req = ResearchRequest::new("what? when? where?").with_depth(0);
        assert_eq!(c.classify_sync(&req), ResearchTier::Deep);
    }

    #[test]
    fn builders_override_thresholds() {
        let c = HeuristicIntentClassifier::new()
            .with_max_shallow_query_chars(10)
            .with_max_shallow_depth(0)
            .with_comparative_markers(Vec::<String>::new());
        let req = ResearchRequest::new("hello world").with_depth(0);
        // 11 chars >= 10 → deep.
        assert_eq!(c.classify_sync(&req), ResearchTier::Deep);

        let short = ResearchRequest::new("hi").with_depth(0);
        assert_eq!(c.classify_sync(&short), ResearchTier::Shallow);
    }
}