use regex::Regex;
use std::sync::OnceLock;
use super::intent::QueryIntent;
/// Namespace type for the regex-based query-intent heuristics.
///
/// The associated functions defined in this file take no `self`, so the type
/// carries no state of its own.
pub struct QueryClassifier;
// Lazily compiled regex caches used by `QueryClassifier::classify`. Each cell
// starts empty and is filled the first time the classifier needs it. They are
// grouped below in the order the classifier consults them.

// Relationship phrases (checked first).
static ENTITY_USAGE_RE: OnceLock<Regex> = OnceLock::new();
static ENTITY_DEF_RE: OnceLock<Regex> = OnceLock::new();
static ENTITY_BUG_RE: OnceLock<Regex> = OnceLock::new();

// Domain vocabulary.
static DOMAIN_DEF_RE: OnceLock<Regex> = OnceLock::new();
static EXTENDED_BUG_RE: OnceLock<Regex> = OnceLock::new();

// Generic keyword tiers.
static USAGE_RE: OnceLock<Regex> = OnceLock::new();
static DEFINITION_RE: OnceLock<Regex> = OnceLock::new();
static CONCEPTUAL_RE: OnceLock<Regex> = OnceLock::new();
static BUG_DEBT_RE: OnceLock<Regex> = OnceLock::new();

// Shape-based fallbacks, applied to the trimmed query.
static LONG_NL_RE: OnceLock<Regex> = OnceLock::new();
static PASCAL_IDENT_RE: OnceLock<Regex> = OnceLock::new();
static SNAKE_IDENT_RE: OnceLock<Regex> = OnceLock::new();
static SCREAM_IDENT_RE: OnceLock<Regex> = OnceLock::new();
static ACRONYM_HINT_RE: OnceLock<Regex> = OnceLock::new();
static MULTI_NOUN_RE: OnceLock<Regex> = OnceLock::new();
impl QueryClassifier {
pub fn classify(query: &str) -> QueryIntent {
let def_re = DEFINITION_RE.get_or_init(|| {
Regex::new(
r"(?i)\b(fn |struct |impl |trait |enum |type |def |class |function |define)\b",
)
.expect("static regex pattern must compile")
});
let usage_re = USAGE_RE.get_or_init(|| {
Regex::new(r"(?i)\b(where is|callers of|who calls|uses of|usages|called by)\b")
.expect("static regex pattern must compile")
});
let conceptual_re = CONCEPTUAL_RE.get_or_init(|| {
Regex::new(r"(?i)\b(how does|what is|explain|overview|architecture|design|why)\b")
.expect("static regex pattern must compile")
});
let bug_re = BUG_DEBT_RE.get_or_init(|| {
Regex::new(r"(?i)\b(TODO|FIXME|HACK|panic!|unwrap\(\)|bug|error|crash|fail)\b")
.expect("static regex pattern must compile")
});
let entity_def_re = ENTITY_DEF_RE.get_or_init(|| {
Regex::new(r"(?i)\b(implements|derives from|aliased as)\b")
.expect("static regex pattern must compile")
});
let entity_usage_re = ENTITY_USAGE_RE.get_or_init(|| {
Regex::new(r"(?i)\b(tested by|co-occurs)\b").expect("static regex pattern must compile")
});
let entity_bug_re = ENTITY_BUG_RE.get_or_init(|| {
Regex::new(r"(?i)\b(raises|documented by)\b")
.expect("static regex pattern must compile")
});
let domain_def_re = DOMAIN_DEF_RE.get_or_init(|| {
Regex::new(
r"(?x)
# Pattern A — standalone structural vocabulary word
(?i)\b(definition|interface|schema|model|enum)\b
|
# Pattern B — PascalCase identifier + structural keyword
\b[A-Z][a-zA-Z0-9]+\s+(?i)(definition|struct|class|interface|type|schema|enum|trait|model)\b
",
)
.expect("static regex pattern must compile")
});
let extended_bug_re = EXTENDED_BUG_RE.get_or_init(|| {
Regex::new(
r"(?i)\b(error\s+handling|deprecated|legacy|missing\s+validation|hardcoded)\b",
)
.expect("static regex pattern must compile")
});
let long_nl_re = LONG_NL_RE.get_or_init(|| {
Regex::new(r"^[^():_.]+(?:\s+[^():_.]+){5,}$")
.expect("static regex pattern must compile")
});
if entity_usage_re.is_match(query) {
return QueryIntent::Usage;
}
if entity_def_re.is_match(query) {
return QueryIntent::Definition;
}
if entity_bug_re.is_match(query) {
return QueryIntent::BugDebt;
}
if domain_def_re.is_match(query) {
return QueryIntent::Definition;
}
if extended_bug_re.is_match(query) {
return QueryIntent::BugDebt;
}
if usage_re.is_match(query) {
return QueryIntent::Usage;
}
if def_re.is_match(query) {
return QueryIntent::Definition;
}
if conceptual_re.is_match(query) {
return QueryIntent::Conceptual;
}
if bug_re.is_match(query) {
return QueryIntent::BugDebt;
}
let trimmed = query.trim();
if long_nl_re.is_match(trimmed) {
return QueryIntent::Conceptual;
}
let pascal_ident_re = PASCAL_IDENT_RE.get_or_init(|| {
Regex::new(
r"\b(?:[A-Z][a-z]+[A-Z][a-zA-Z0-9]*|[A-Z]{2,}(?:[0-9]+[A-Za-z][a-zA-Z0-9]*|[0-9]+|[A-Z][a-z][a-zA-Z0-9]*))\b",
)
.expect("static regex pattern must compile")
});
if pascal_ident_re.is_match(trimmed) {
return QueryIntent::Definition;
}
let snake_ident_re = SNAKE_IDENT_RE.get_or_init(|| {
Regex::new(r"^[a-z][a-z0-9_]*_[a-z0-9_]+$").expect("static regex pattern must compile")
});
if snake_ident_re.is_match(trimmed) {
return QueryIntent::Definition;
}
let scream_ident_re = SCREAM_IDENT_RE.get_or_init(|| {
Regex::new(r"^[A-Z][A-Z0-9]*(?:_[A-Z0-9]+)+$")
.expect("static regex pattern must compile")
});
if scream_ident_re.is_match(trimmed) {
return QueryIntent::Definition;
}
let acronym_hint_re = ACRONYM_HINT_RE.get_or_init(|| {
Regex::new(r"\b[A-Z]{2,}[0-9]*\b").expect("static regex pattern must compile")
});
if acronym_hint_re.is_match(trimmed) {
let token_count = trimmed.split_whitespace().count();
let has_nl_words = trimmed
.split_whitespace()
.any(|t| t.chars().next().is_some_and(|c| c.is_lowercase()));
if token_count <= 2 || !has_nl_words {
return QueryIntent::Definition;
}
}
let multi_noun_re = MULTI_NOUN_RE.get_or_init(|| {
Regex::new(r"^[A-Za-z0-9]+(?:\s+[A-Za-z0-9]+){3,}$")
.expect("static regex pattern must compile")
});
if multi_noun_re.is_match(trimmed) {
return QueryIntent::Conceptual;
}
QueryIntent::Unknown
}
pub fn classify_with_domain(query: &str, domain_terms: &[String]) -> QueryIntent {
let base = Self::classify(query);
if base != QueryIntent::Unknown {
return base;
}
if domain_terms.is_empty() {
return base;
}
let q = query.to_lowercase();
for term in domain_terms {
let t = term.trim();
if t.is_empty() {
continue;
}
if q.contains(&t.to_lowercase()) {
return QueryIntent::Definition;
}
}
base
}
}