use async_trait::async_trait;
use atomr_agents_deep_research_core::ResearchRequest;
use crate::error::Result;
/// Coarse effort level an [`IntentClassifier`] assigns to a research request.
///
/// `Shallow` is the cheap path; `Deep` is selected when a classifier decides
/// the request needs more work (see `HeuristicIntentClassifier::classify_sync`
/// for the built-in heuristics).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ResearchTier {
    /// Lightweight handling.
    Shallow,
    /// Full, more expensive handling.
    Deep,
}
/// Strategy that decides which [`ResearchTier`] a [`ResearchRequest`] should
/// be handled at.
///
/// The trait requires `Send + Sync + 'static` of every implementor.
#[async_trait]
pub trait IntentClassifier: Sync + Send + 'static {
    /// Classifies `req` into a tier.
    ///
    /// # Errors
    /// Implementation-defined; the returned error type is the crate's
    /// `error::Result` error.
    async fn classify(&self, req: &ResearchRequest) -> Result<ResearchTier>;
}
/// Rule-based [`IntentClassifier`] that routes a request to
/// [`ResearchTier::Deep`] when any simple threshold or keyword check trips,
/// and to [`ResearchTier::Shallow`] otherwise.
///
/// All fields are public and may be set directly or through the `with_*`
/// builder methods.
#[derive(Clone, Debug)]
pub struct HeuristicIntentClassifier {
    /// Query-length threshold, counted in Unicode scalar values (`chars()`),
    /// for shallow routing; see `classify_sync` for the exact comparison.
    pub max_shallow_query_chars: usize,
    /// A query containing more `?` characters than this is routed Deep.
    pub max_shallow_question_marks: usize,
    /// A request whose `depth` exceeds this is routed Deep.
    pub max_shallow_depth: u32,
    /// Case-insensitive substrings whose presence in the query routes it Deep.
    pub comparative_markers: Vec<String>,
}
impl Default for HeuristicIntentClassifier {
fn default() -> Self {
Self::new()
}
}
impl HeuristicIntentClassifier {
    /// Creates a classifier with the built-in defaults.
    ///
    /// The defaults allow at most 80 query characters, at most one `?`, a
    /// depth of at most 1, and the default comparative-marker list.
    pub fn new() -> Self {
        Self {
            max_shallow_query_chars: 80,
            max_shallow_question_marks: 1,
            max_shallow_depth: 1,
            comparative_markers: default_comparative_markers(),
        }
    }

    /// Sets the longest query, in `char`s, that may still be routed shallow.
    pub fn with_max_shallow_query_chars(mut self, n: usize) -> Self {
        self.max_shallow_query_chars = n;
        self
    }

    /// Sets the largest number of `?` characters a shallow query may contain.
    pub fn with_max_shallow_question_marks(mut self, n: usize) -> Self {
        self.max_shallow_question_marks = n;
        self
    }

    /// Sets the largest request `depth` that may still be routed shallow.
    pub fn with_max_shallow_depth(mut self, n: u32) -> Self {
        self.max_shallow_depth = n;
        self
    }

    /// Replaces (does not extend) the comparative-marker list.
    ///
    /// Markers are matched as case-insensitive substrings of the query.
    /// Empty markers are ignored during matching.
    pub fn with_comparative_markers<I, S>(mut self, markers: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.comparative_markers = markers.into_iter().map(Into::into).collect();
        self
    }

    /// Synchronously classifies `req`.
    ///
    /// The request is routed [`ResearchTier::Deep`] if any of these hold:
    /// - the query has more than `max_shallow_query_chars` `char`s;
    /// - it contains more than `max_shallow_question_marks` `?`s;
    /// - `req.depth` exceeds `max_shallow_depth`;
    /// - it contains any non-empty comparative marker (case-insensitive).
    ///
    /// Otherwise the request is routed [`ResearchTier::Shallow`]. Every
    /// `max_shallow_*` limit is inclusive: a value equal to the limit is
    /// still shallow.
    pub fn classify_sync(&self, req: &ResearchRequest) -> ResearchTier {
        let query = req.query.as_str();

        // `>` (not `>=`) so the limit is inclusive, matching the other two
        // thresholds and the "max shallow" naming.
        if query.chars().count() > self.max_shallow_query_chars {
            return ResearchTier::Deep;
        }

        let question_marks = query.chars().filter(|&c| c == '?').count();
        if question_marks > self.max_shallow_question_marks {
            return ResearchTier::Deep;
        }

        if req.depth > self.max_shallow_depth {
            return ResearchTier::Deep;
        }

        let lowered = query.to_lowercase();
        let has_marker = self
            .comparative_markers
            .iter()
            // `str::contains("")` is always true, so an empty marker would
            // force every query Deep; skip it instead.
            .filter(|marker| !marker.is_empty())
            .any(|marker| lowered.contains(marker.to_lowercase().as_str()));

        if has_marker {
            ResearchTier::Deep
        } else {
            ResearchTier::Shallow
        }
    }
}
#[async_trait]
impl IntentClassifier for HeuristicIntentClassifier {
    /// Delegates to [`HeuristicIntentClassifier::classify_sync`].
    ///
    /// This implementation always returns `Ok`.
    async fn classify(&self, req: &ResearchRequest) -> Result<ResearchTier> {
        let tier = self.classify_sync(req);
        Ok(tier)
    }
}
/// Builds the default marker list used by [`HeuristicIntentClassifier::new`].
///
/// Some entries carry deliberate surrounding spaces (e.g. `" vs "`,
/// `"how do "`) so they only match as separate words within the query.
fn default_comparative_markers() -> Vec<String> {
    const MARKERS: [&str; 12] = [
        "compare",
        "versus",
        " vs ",
        " vs.",
        "trade-off",
        "tradeoff",
        "analyze",
        "deep dive",
        "research",
        "contrast",
        "differences between",
        "how do ",
    ];
    MARKERS.iter().copied().map(String::from).collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_deep_research_core::ResearchRequest;

    /// Runs the synchronous classifier on an owned request.
    fn tier_of(classifier: &HeuristicIntentClassifier, request: ResearchRequest) -> ResearchTier {
        classifier.classify_sync(&request)
    }

    #[test]
    fn defaults_route_short_queries_shallow() {
        let classifier = HeuristicIntentClassifier::new();
        let request = ResearchRequest::new("rust").with_depth(1);
        assert_eq!(tier_of(&classifier, request), ResearchTier::Shallow);
    }

    #[test]
    fn comparative_markers_force_deep() {
        let classifier = HeuristicIntentClassifier::new();
        let request = ResearchRequest::new("compare actor frameworks");
        assert_eq!(tier_of(&classifier, request), ResearchTier::Deep);
    }

    #[test]
    fn depth_above_threshold_forces_deep() {
        let classifier = HeuristicIntentClassifier::new();
        let request = ResearchRequest::new("rust").with_depth(3);
        assert_eq!(tier_of(&classifier, request), ResearchTier::Deep);
    }

    #[test]
    fn long_queries_force_deep() {
        let classifier = HeuristicIntentClassifier::new();
        let long_query: String = std::iter::repeat('a').take(120).collect();
        let request = ResearchRequest::new(long_query).with_depth(0);
        assert_eq!(tier_of(&classifier, request), ResearchTier::Deep);
    }

    #[test]
    fn multiple_question_marks_force_deep() {
        let classifier = HeuristicIntentClassifier::new();
        let request = ResearchRequest::new("what? when? where?").with_depth(0);
        assert_eq!(tier_of(&classifier, request), ResearchTier::Deep);
    }

    #[test]
    fn builders_override_thresholds() {
        let classifier = HeuristicIntentClassifier::new()
            .with_max_shallow_query_chars(10)
            .with_max_shallow_depth(0)
            .with_comparative_markers(std::iter::empty::<String>());

        let over_limit = ResearchRequest::new("hello world").with_depth(0);
        assert_eq!(tier_of(&classifier, over_limit), ResearchTier::Deep);

        let under_limit = ResearchRequest::new("hi").with_depth(0);
        assert_eq!(tier_of(&classifier, under_limit), ResearchTier::Shallow);
    }
}