use hirn_query::ast::RetrievalMode;
/// Coarse complexity bucket assigned to a query by [`classify_query`].
///
/// [`classify_and_route`] turns each bucket into a retrieval mode:
/// `Simple` → `Local`, `Moderate` → `Hybrid`, `Complex` → `Raptor`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryComplexity {
    /// Total classifier score below 3.
    Simple,
    /// Total classifier score from 3 up to (but not including) 6.
    Moderate,
    /// Total classifier score of 6 or more.
    Complex,
}
/// Classifies `query` and selects the retrieval mode suited to its complexity.
///
/// Every argument is forwarded unchanged to [`classify_query`]; the resulting
/// bucket is then mapped as follows:
///
/// * [`QueryComplexity::Simple`]   → `RetrievalMode::Local`
/// * [`QueryComplexity::Moderate`] → `RetrievalMode::Hybrid`
/// * [`QueryComplexity::Complex`]  → `RetrievalMode::Raptor`
pub fn classify_and_route(
    query: &str,
    involving_count: usize,
    where_count: usize,
    has_temporal: bool,
    has_expand: bool,
    has_follow_causes: bool,
) -> RetrievalMode {
    match classify_query(
        query,
        involving_count,
        where_count,
        has_temporal,
        has_expand,
        has_follow_causes,
    ) {
        QueryComplexity::Complex => RetrievalMode::Raptor,
        QueryComplexity::Moderate => RetrievalMode::Hybrid,
        QueryComplexity::Simple => RetrievalMode::Local,
    }
}
/// Scores a natural-language query and buckets it into a [`QueryComplexity`].
///
/// The score is the sum of these independent signals:
///
/// * **Length**: `+1` for 4–9 whitespace-separated tokens, `+2` for 10–19,
///   `+3` for 20 or more.
/// * **`where_count`**: `+1` per unit, capped at `+3`.
/// * **`involving_count`**: `+1` when it is 1 or 2, `+2` when it is above 2.
/// * **Analytical phrasing**: `+2` for every complex pattern ("compare",
///   "why", "trade-off", …) found in the lower-cased query.
/// * **Interrogative phrasing**: `+1` per moderate pattern ("how", "who",
///   "list", …) found, capped at `+2`. The two pattern sets are scored
///   independently, so "how does" also counts as "how".
/// * **Flags**: `+2` for `has_temporal`; `+3` each for `has_expand` and
///   `has_follow_causes`.
///
/// A total of 6 or more is [`QueryComplexity::Complex`], 3–5 is
/// [`QueryComplexity::Moderate`] and anything lower is
/// [`QueryComplexity::Simple`].
///
/// Patterns only count when they start at a word boundary. That keeps "show"
/// from matching "how" and "playlist" from matching "list". Only the start is
/// anchored, so inflected forms ("compared", "trade-offs") still match, and so
/// do words that merely begin with a pattern (e.g. "whole" for "who").
pub fn classify_query(
    query: &str,
    involving_count: usize,
    where_count: usize,
    has_temporal: bool,
    has_expand: bool,
    has_follow_causes: bool,
) -> QueryComplexity {
    const COMPLEX_PATTERNS: &[&str] = &[
        "compare",
        "contrast",
        "why",
        "how does",
        "what caused",
        "relationship between",
        "difference between",
        "trade-off",
        "pros and cons",
        "implications of",
        "summarize all",
        "overview of",
        "explain the",
        "analyze",
    ];
    const MODERATE_PATTERNS: &[&str] = &[
        "how", "what are", "describe", "list", "when did", "where", "who", "which",
    ];

    let mut score: u32 = 0;

    let token_count = query.split_whitespace().count();
    if token_count >= 20 {
        score += 3;
    } else if token_count >= 10 {
        score += 2;
    } else if token_count >= 4 {
        score += 1;
    }

    // Clamp while still a `usize`. Casting first would silently truncate
    // values above `u32::MAX` (e.g. 2^32 would become 0).
    score += where_count.min(3) as u32;

    score += match involving_count {
        0 => 0,
        1 | 2 => 1,
        _ => 2,
    };

    let lower = query.to_lowercase();
    let complex_hits = COMPLEX_PATTERNS
        .iter()
        .filter(|&&pattern| contains_at_word_start(&lower, pattern))
        .count();
    let moderate_hits = MODERATE_PATTERNS
        .iter()
        .filter(|&&pattern| contains_at_word_start(&lower, pattern))
        .count();
    // Both counts are bounded by their (small) pattern tables, so the casts
    // cannot truncate.
    score += (complex_hits as u32) * 2;
    score += moderate_hits.min(2) as u32;

    if has_temporal {
        score += 2;
    }
    if has_expand {
        score += 3;
    }
    if has_follow_causes {
        score += 3;
    }

    if score >= 6 {
        QueryComplexity::Complex
    } else if score >= 3 {
        QueryComplexity::Moderate
    } else {
        QueryComplexity::Simple
    }
}

/// Returns `true` when `needle` occurs in `haystack` at a position that is
/// either the very start of `haystack` or directly after a non-alphanumeric
/// character (space, punctuation, hyphen, …).
fn contains_at_word_start(haystack: &str, needle: &str) -> bool {
    haystack.match_indices(needle).any(|(start, _)| {
        // `start` comes from `match_indices`, so it is a valid char boundary.
        !haystack[..start]
            .chars()
            .next_back()
            .is_some_and(char::is_alphanumeric)
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simple_factoid_query() {
        // 3 tokens (+0), no entities/filters, no patterns => score 0.
        let c = classify_query("what is JWT", 0, 0, false, false, false);
        assert_eq!(c, QueryComplexity::Simple);
    }

    #[test]
    fn moderate_query_with_entity() {
        // 7 tokens (+1), one entity (+1), "how does" (+2), "how" (+1) => 5.
        let c = classify_query(
            "how does authentication work with OAuth tokens",
            1,
            0,
            false,
            false,
            false,
        );
        assert_eq!(c, QueryComplexity::Moderate);
    }

    #[test]
    fn complex_analytical_query() {
        // 11 tokens (+2), where 1 (+1), 3 entities (+2),
        // "compare" + "trade-off" (+4), expand (+3) => 12.
        let c = classify_query(
            "compare the trade-off between JWT and session-based authentication across all services",
            3,
            1,
            false,
            true,
            false,
        );
        assert_eq!(c, QueryComplexity::Complex);
    }

    #[test]
    fn temporal_adds_complexity() {
        // 4 tokens (+1), temporal (+2) => 3.
        let c = classify_query("what happened with deployments", 0, 0, true, false, false);
        assert_eq!(c, QueryComplexity::Moderate);
    }

    #[test]
    fn follow_causes_is_complex() {
        // 5 tokens (+1), "why" (+2), follow causes (+3) => 6.
        let c = classify_query("why did the service fail", 0, 0, false, false, true);
        assert_eq!(c, QueryComplexity::Complex);
    }

    #[test]
    fn where_count_is_capped() {
        // 1 token (+0), where_count contributes at most +3 => 3, not Complex.
        let c = classify_query("hello", 0, 50, false, false, false);
        assert_eq!(c, QueryComplexity::Moderate);
    }

    #[test]
    fn classify_and_route_simple() {
        // Score 0 => Simple => Local.
        let mode = classify_and_route("hello", 0, 0, false, false, false);
        assert_eq!(mode, RetrievalMode::Local);
    }

    #[test]
    fn classify_and_route_moderate() {
        // Same input as `moderate_query_with_entity` (score 5) => Hybrid.
        let mode = classify_and_route(
            "how does authentication work with OAuth tokens",
            1,
            0,
            false,
            false,
            false,
        );
        assert_eq!(mode, RetrievalMode::Hybrid);
    }

    #[test]
    fn classify_and_route_complex() {
        // 7 tokens (+1), where 1 (+1), 2 entities (+1),
        // "compare" + "trade-off" (+4), temporal (+2), expand (+3) => 12.
        let mode = classify_and_route(
            "compare all authentication strategies and their trade-offs",
            2,
            1,
            true,
            true,
            false,
        );
        assert_eq!(mode, RetrievalMode::Raptor);
    }
}