use crate::language::ChunkType;
/// Coarse intent assigned to a search query by [`classify_query`].
///
/// `Display` and [`QueryCategory::as_str`] yield a stable snake_case label.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryCategory {
    /// A symbol name or path, e.g. `search_filtered` or `HashMap::new`.
    IdentifierLookup,
    /// A question about code shape, e.g. "functions that return Result".
    Structural,
    /// A description of what code does, e.g. "validates user input".
    Behavioral,
    /// A design-level concept, e.g. "dependency injection pattern".
    Conceptual,
    /// Several chained steps or conditions, e.g. "find errors and then retry them".
    MultiStep,
    /// A query that excludes something, e.g. "sort without allocating".
    Negation,
    /// A query explicitly restricted to chunk types, e.g. "all structs".
    TypeFiltered,
    /// A query relating code across programming languages.
    CrossLanguage,
    /// No rule matched (also used for empty queries).
    Unknown,
}

impl QueryCategory {
    /// Returns the snake_case label for this category (the same text `Display` writes).
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::IdentifierLookup => "identifier_lookup",
            Self::Structural => "structural",
            Self::Behavioral => "behavioral",
            Self::Conceptual => "conceptual",
            Self::MultiStep => "multi_step",
            Self::Negation => "negation",
            Self::TypeFiltered => "type_filtered",
            Self::CrossLanguage => "cross_language",
            Self::Unknown => "unknown",
        }
    }
}

impl std::fmt::Display for QueryCategory {
    /// Writes [`QueryCategory::as_str`] through `Formatter::pad`, so width,
    /// fill, alignment and precision flags (e.g. `{:<16}`) are honoured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// How strongly [`classify_query`] believes the category it assigned.
///
/// `Display` and [`Confidence::as_str`] yield a lowercase label.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Confidence {
    /// A strong, specific signal matched.
    High,
    /// A heuristic signal matched.
    Medium,
    /// Weak or no signal.
    Low,
}

impl Confidence {
    /// Returns the lowercase label for this level (the same text `Display` writes).
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::High => "high",
            Self::Medium => "medium",
            Self::Low => "low",
        }
    }
}

impl std::fmt::Display for Confidence {
    /// Writes [`Confidence::as_str`] through `Formatter::pad`, so width, fill,
    /// alignment and precision flags are honoured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// Search strategy recommended by [`classify_query`] for a query.
///
/// `Display` and [`SearchStrategy::as_str`] yield a short snake_case label.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchStrategy {
    /// Name-based lookup; chosen for identifier-like queries.
    NameOnly,
    /// The default dense search path.
    DenseDefault,
    /// Dense search intended to be narrowed by chunk-type hints.
    DenseWithTypeHints,
    /// Presumably dense search combined with SPLADE sparse retrieval — confirm
    /// with the search layer. Not selected by `classify_query` in this module.
    DenseWithSplade,
}

impl SearchStrategy {
    /// Returns the label for this strategy (the same text `Display` writes).
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::NameOnly => "name_only",
            Self::DenseDefault => "dense",
            Self::DenseWithTypeHints => "dense_type_hints",
            Self::DenseWithSplade => "dense_splade",
        }
    }
}

impl std::fmt::Display for SearchStrategy {
    /// Writes [`SearchStrategy::as_str`] through `Formatter::pad`, so width,
    /// fill, alignment and precision flags are honoured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// Result of [`classify_query`]: the detected category, how sure the
/// classifier is, and which search strategy to use.
///
/// Derives `PartialEq` so whole classifications can be compared directly
/// (every field type already implements it).
#[derive(Debug, Clone, PartialEq)]
pub struct Classification {
    /// Detected intent of the query.
    pub category: QueryCategory,
    /// How confident the classifier is in `category`.
    pub confidence: Confidence,
    /// Search strategy recommended for the query.
    pub strategy: SearchStrategy,
    /// Chunk types the query explicitly asks for (as found by
    /// `extract_type_hints`); `None` when no such filter was detected.
    pub type_hints: Option<Vec<ChunkType>>,
}
/// Common English words whose presence marks a query as natural language
/// rather than a symbol name.
///
/// Compared as whole words against the lowercased, whitespace-split query by
/// `is_identifier_query` (any match rules out an identifier lookup) and by
/// `is_conceptual_query` (a match in a short query makes it conceptual).
const NL_INDICATORS: &[&str] = &[
"the",
"a",
"an",
"that",
"which",
"how",
"what",
"where",
"when",
"find",
"get",
"all",
"every",
"each",
"with",
"without",
"for",
"from",
"into",
"this",
"does",
"code",
"function",
"method",
"implement",
"using",
];
/// Third-person present-tense verbs describing what code does.
///
/// Any whole-word match in the lowercased query makes `is_behavioral_query`
/// return true. Base forms ("validate") are not listed and do not match.
const BEHAVIORAL_VERBS: &[&str] = &[
"validates",
"processes",
"handles",
"manages",
"computes",
"parses",
"converts",
"transforms",
"filters",
"sorts",
"checks",
"verifies",
"sends",
"receives",
"reads",
"writes",
"creates",
"deletes",
"updates",
"serializes",
"deserializes",
"encodes",
"decodes",
"authenticates",
"authorizes",
"logs",
"retries",
"caches",
"renders",
];
/// Nouns naming design-level concepts.
///
/// Any whole-word match in the lowercased query makes `is_conceptual_query`
/// return true.
const CONCEPTUAL_NOUNS: &[&str] = &[
"pattern",
"architecture",
"design",
"approach",
"strategy",
"algorithm",
"principle",
"abstraction",
"convention",
"idiom",
"paradigm",
"concept",
"technique",
"methodology",
];
/// Markers of a negated query such as "sort without allocating".
///
/// `classify_query` checks for these before every other category rule, so a
/// match forces `QueryCategory::Negation`. Each entry ends with a space.
const NEGATION_WORDS: &[&str] = &[
"not ",
"without ",
"except ",
"never ",
"avoid ",
"no ",
"don't ",
"doesn't ",
"shouldn't ",
"exclude ",
];
/// Type-definition keywords that mark a query as structural in
/// `is_structural_query`.
const STRUCTURAL_KEYWORDS: &[&str] = &[
"struct",
"enum",
"trait",
"impl",
"interface",
"class",
"module",
"namespace",
"protocol",
"type",
];
/// Lowercase programming-language names, including symbol spellings such as
/// `c++` and `c#`.
///
/// `is_cross_language_query` counts how many of these appear as words in the
/// query. NOTE(review): "go" is also an ordinary English verb, so it can
/// produce spurious language matches.
const LANGUAGE_NAMES: &[&str] = &[
"python",
"rust",
"javascript",
"typescript",
"java",
"go",
"ruby",
"c++",
"cpp",
"csharp",
"c#",
"swift",
"kotlin",
"scala",
"elixir",
"haskell",
"ocaml",
"php",
"perl",
"lua",
"sql",
];
/// Phrases that signal a question about code structure (signatures, return
/// types, inheritance).
///
/// `is_structural_query` matches these as substrings of the query.
const STRUCTURAL_PATTERNS: &[&str] = &[
"functions that",
"methods that",
"types that",
"structs that",
"that return",
"that take",
"that accept",
"with signature",
"implementing",
"extending",
"deriving",
];
/// Connectives suggesting a query with several steps or conditions.
///
/// `classify_query` checks these as substrings of the lowercased query, as the
/// last rule before falling back to `Unknown`. The padding spaces act as crude
/// word boundaries. NOTE(review): "then " still matches inside e.g.
/// "strengthen ", and "after " inside "rafter ".
const MULTISTEP_PATTERNS: &[&str] = &[
"and then", "before ", "after ", " and ", " or ", "first ", "then ", "both ", "between ",
];
pub fn classify_query(query: &str) -> Classification {
let query_lower = query.to_lowercase();
let words: Vec<&str> = query_lower.split_whitespace().collect();
if words.is_empty() {
return Classification {
category: QueryCategory::Unknown,
confidence: Confidence::Low,
strategy: SearchStrategy::DenseDefault,
type_hints: None,
};
}
if NEGATION_WORDS.iter().any(|w| query_lower.contains(w)) {
return Classification {
category: QueryCategory::Negation,
confidence: Confidence::High,
strategy: SearchStrategy::DenseDefault,
type_hints: None,
};
}
if is_identifier_query(&query_lower, &words) {
return Classification {
category: QueryCategory::IdentifierLookup,
confidence: Confidence::High,
strategy: SearchStrategy::NameOnly,
type_hints: None,
};
}
if is_cross_language_query(&query_lower, &words) {
return Classification {
category: QueryCategory::CrossLanguage,
confidence: Confidence::High,
strategy: SearchStrategy::DenseDefault,
type_hints: None,
};
}
let type_hints = extract_type_hints(&query_lower);
if type_hints.is_some() {
return Classification {
category: QueryCategory::TypeFiltered,
confidence: Confidence::Medium,
strategy: SearchStrategy::DenseWithTypeHints,
type_hints,
};
}
if is_structural_query(&query_lower) {
return Classification {
category: QueryCategory::Structural,
confidence: Confidence::Medium,
strategy: SearchStrategy::DenseWithTypeHints,
type_hints: None,
};
}
if is_behavioral_query(&query_lower, &words) {
return Classification {
category: QueryCategory::Behavioral,
confidence: Confidence::Medium,
strategy: SearchStrategy::DenseDefault,
type_hints: None,
};
}
if is_conceptual_query(&query_lower, &words) {
return Classification {
category: QueryCategory::Conceptual,
confidence: Confidence::Medium,
strategy: SearchStrategy::DenseDefault,
type_hints: None,
};
}
if MULTISTEP_PATTERNS.iter().any(|p| query_lower.contains(p)) {
return Classification {
category: QueryCategory::MultiStep,
confidence: Confidence::Low,
strategy: SearchStrategy::DenseDefault,
type_hints: None,
};
}
Classification {
category: QueryCategory::Unknown,
confidence: Confidence::Low,
strategy: SearchStrategy::DenseDefault,
type_hints: None,
}
}
fn is_identifier_query(_query: &str, words: &[&str]) -> bool {
if words.len() == 1 {
let w = words[0];
if !w.chars().any(|c| c.is_alphabetic()) {
return false;
}
if NL_INDICATORS.contains(&w) {
return false;
}
return w
.chars()
.all(|c| c.is_alphanumeric() || c == '_' || c == ':' || c == '.' || c == '/');
}
if words.len() <= 3 {
let has_nl = words.iter().any(|w| NL_INDICATORS.contains(w));
if has_nl {
return false;
}
let has_identifier_signal = words.iter().any(|w| {
w.contains('_')
|| w.contains("::")
|| w.contains('.')
|| (w.chars().any(|c| c.is_uppercase()) && w.chars().any(|c| c.is_lowercase()))
});
let all_identifier_chars = words.iter().all(|w| {
w.chars()
.all(|c| c.is_alphanumeric() || c == '_' || c == ':' || c == '.')
});
return has_identifier_signal && all_identifier_chars;
}
false
}
/// Returns true when the query relates code across programming languages.
///
/// That is the case when two or more distinct [`LANGUAGE_NAMES`] appear as
/// words, or when exactly one does and the query also has a translation cue.
/// The cues are the substrings "equivalent" or "translate", or the whole words
/// "port" or "convert".
///
/// `words` are expected to be the lowercased, whitespace-split `query`.
/// Trailing sentence punctuation is trimmed before comparing ("in Rust?").
/// Whole-word cues avoid false hits from "export", "import", "support" and
/// similar words that merely contain "port".
fn is_cross_language_query(query: &str, words: &[&str]) -> bool {
    // Keep symbols that belong to names such as `c++` and `c#`.
    let bare_words: Vec<&str> = words
        .iter()
        .map(|w| {
            w.trim_matches(|c: char| {
                matches!(c, ',' | '.' | '?' | '!' | ';' | ':' | '(' | ')' | '"' | '\'')
            })
        })
        .collect();
    let lang_count = LANGUAGE_NAMES
        .iter()
        .filter(|&&lang| bare_words.contains(&lang))
        .count();
    if lang_count >= 2 {
        return true;
    }
    lang_count == 1
        && (query.contains("equivalent")
            || query.contains("translate")
            || bare_words.iter().any(|w| matches!(*w, "port" | "convert")))
}
/// Returns true when the query asks about code structure.
///
/// Matches any [`STRUCTURAL_PATTERNS`] phrase as a substring. It also matches
/// any [`STRUCTURAL_KEYWORDS`] entry as a whole whitespace-separated word
/// anywhere in the query, including the last word ("find the enum") and words
/// separated by tabs. This is a superset of the previous "` kw `" / "`kw `"
/// checks and needs no per-call string allocation.
fn is_structural_query(query: &str) -> bool {
    STRUCTURAL_PATTERNS.iter().any(|p| query.contains(p))
        || query
            .split_whitespace()
            .any(|w| STRUCTURAL_KEYWORDS.contains(&w))
}
/// Returns true when the query describes what code does.
///
/// A query is behavioral if any of its (lowercased) words is one of
/// [`BEHAVIORAL_VERBS`], or if it contains one of the phrases "how does",
/// "what does", "code that" or "function that".
fn is_behavioral_query(query: &str, words: &[&str]) -> bool {
    const BEHAVIOR_PHRASES: [&str; 4] = ["how does", "what does", "code that", "function that"];
    let mentions_verb = words.iter().any(|word| BEHAVIORAL_VERBS.contains(word));
    mentions_verb || BEHAVIOR_PHRASES.iter().any(|phrase| query.contains(phrase))
}
/// Returns true when the query asks about a design-level concept.
///
/// A query is conceptual when any word is one of [`CONCEPTUAL_NOUNS`]. It is
/// also conceptual when it has at most three words, contains one of
/// [`NL_INDICATORS`], and is not structural according to
/// [`is_structural_query`].
fn is_conceptual_query(query: &str, words: &[&str]) -> bool {
    if words.iter().any(|word| CONCEPTUAL_NOUNS.contains(word)) {
        return true;
    }
    if words.len() > 3 {
        return false;
    }
    let has_nl_word = words.iter().any(|word| NL_INDICATORS.contains(word));
    has_nl_word && !is_structural_query(query)
}
/// Extracts explicit chunk-type filters from a query.
///
/// Looks for fixed phrases such as "all structs", "every trait" or
/// "test function" anywhere in `query` (substring match) and maps each to a
/// [`ChunkType`]. The phrases are lowercase, so callers should pass a
/// lowercased query.
///
/// Each chunk type appears at most once in the result, in the order of its
/// first matching phrase; overlapping phrases such as "every test function"
/// previously produced duplicates. Returns `None` when nothing matches.
pub fn extract_type_hints(query: &str) -> Option<Vec<ChunkType>> {
    let patterns: &[(&str, ChunkType)] = &[
        ("test function", ChunkType::Test),
        ("test method", ChunkType::Test),
        ("all tests", ChunkType::Test),
        ("every test", ChunkType::Test),
        ("all structs", ChunkType::Struct),
        ("every struct", ChunkType::Struct),
        ("all enums", ChunkType::Enum),
        ("every enum", ChunkType::Enum),
        ("all traits", ChunkType::Trait),
        ("every trait", ChunkType::Trait),
        ("all interfaces", ChunkType::Interface),
        ("every interface", ChunkType::Interface),
        ("all classes", ChunkType::Class),
        ("every class", ChunkType::Class),
        ("endpoint", ChunkType::Endpoint),
        ("all constants", ChunkType::Constant),
        ("every constant", ChunkType::Constant),
        ("all modules", ChunkType::Module),
        ("every module", ChunkType::Module),
    ];
    let mut types: Vec<ChunkType> = Vec::new();
    for &(phrase, chunk_type) in patterns {
        if query.contains(phrase) && !types.contains(&chunk_type) {
            types.push(chunk_type);
        }
    }
    if types.is_empty() {
        None
    } else {
        Some(types)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classify_identifier_snake_case() {
        let c = classify_query("search_filtered");
        assert_eq!(c.category, QueryCategory::IdentifierLookup);
        assert_eq!(c.confidence, Confidence::High);
        assert_eq!(c.strategy, SearchStrategy::NameOnly);
    }

    #[test]
    fn test_classify_identifier_qualified() {
        let c = classify_query("HashMap::new");
        assert_eq!(c.category, QueryCategory::IdentifierLookup);
        assert_eq!(c.confidence, Confidence::High);
    }

    #[test]
    fn test_classify_identifier_camel() {
        let c = classify_query("SearchFilter");
        assert_eq!(c.category, QueryCategory::IdentifierLookup);
        assert_eq!(c.confidence, Confidence::High);
    }

    #[test]
    fn test_classify_behavioral() {
        let c = classify_query("validates user input");
        assert_eq!(c.category, QueryCategory::Behavioral);
        assert_eq!(c.confidence, Confidence::Medium);
        assert_eq!(c.strategy, SearchStrategy::DenseDefault);
    }

    #[test]
    fn test_classify_negation() {
        let c = classify_query("sort without allocating");
        assert_eq!(c.category, QueryCategory::Negation);
        assert_eq!(c.confidence, Confidence::High);
    }

    #[test]
    fn test_classify_structural() {
        let c = classify_query("functions that return Result");
        assert_eq!(c.category, QueryCategory::Structural);
        assert_eq!(c.confidence, Confidence::Medium);
    }

    #[test]
    fn test_classify_type_filtered() {
        let c = classify_query("all test functions");
        assert_eq!(c.category, QueryCategory::TypeFiltered);
        assert!(c.type_hints.is_some());
        assert!(c.type_hints.unwrap().contains(&ChunkType::Test));
    }

    #[test]
    fn test_classify_cross_language() {
        let c = classify_query("Python equivalent of map in Rust");
        assert_eq!(c.category, QueryCategory::CrossLanguage);
        assert_eq!(c.confidence, Confidence::High);
    }

    #[test]
    fn test_classify_conceptual() {
        let c = classify_query("dependency injection pattern");
        assert_eq!(c.category, QueryCategory::Conceptual);
        assert_eq!(c.confidence, Confidence::Medium);
    }

    #[test]
    fn test_classify_multi_step() {
        let c = classify_query("find errors and then retry them");
        assert_eq!(c.category, QueryCategory::MultiStep);
        assert_eq!(c.confidence, Confidence::Low);
    }

    #[test]
    fn test_classify_unknown() {
        let c = classify_query("asdf jkl qwerty");
        assert_eq!(c.category, QueryCategory::Unknown);
        assert_eq!(c.confidence, Confidence::Low);
    }

    #[test]
    fn test_extract_type_hints_struct() {
        let hints = extract_type_hints("find all structs");
        assert!(hints.is_some());
        assert!(hints.unwrap().contains(&ChunkType::Struct));
    }

    #[test]
    fn test_extract_type_hints_none() {
        let hints = extract_type_hints("handle errors gracefully");
        assert!(hints.is_none());
    }

    #[test]
    fn test_classify_empty() {
        let c = classify_query("");
        assert_eq!(c.category, QueryCategory::Unknown);
        assert_eq!(c.confidence, Confidence::Low);
    }

    #[test]
    fn test_classify_whitespace_only() {
        // Whitespace-only input splits into zero words, like the empty query.
        let c = classify_query("   \t ");
        assert_eq!(c.category, QueryCategory::Unknown);
        assert_eq!(c.confidence, Confidence::Low);
        assert!(c.type_hints.is_none());
    }

    #[test]
    fn test_classify_single_char() {
        let c = classify_query("a");
        assert_ne!(c.category, QueryCategory::IdentifierLookup);
    }

    #[test]
    fn test_classify_very_long() {
        let long = "a ".repeat(5000);
        let start = std::time::Instant::now();
        let c = classify_query(&long);
        let elapsed = start.elapsed();
        assert!(elapsed.as_millis() < 100, "Should complete in <100ms");
        assert_eq!(c.category, QueryCategory::Unknown);
        assert_eq!(c.confidence, Confidence::Low);
    }

    #[test]
    fn test_classify_unicode_identifier() {
        let c = classify_query("日本語_関数");
        assert_eq!(c.category, QueryCategory::IdentifierLookup);
    }

    #[test]
    fn test_classify_path_like() {
        let c = classify_query("src/store/mod.rs");
        assert_eq!(c.category, QueryCategory::IdentifierLookup);
    }

    #[test]
    fn test_classify_only_stopwords() {
        let c = classify_query("the a an of");
        assert_ne!(c.category, QueryCategory::IdentifierLookup);
    }

    #[test]
    fn test_classify_special_chars() {
        // Signature punctuation must not panic, and it matches no rule.
        let c = classify_query("fn<T: Hash>()");
        assert_eq!(c.category, QueryCategory::Unknown);
        assert_eq!(c.confidence, Confidence::Low);
    }

    #[test]
    fn test_classify_all_caps() {
        let c = classify_query("WHERE IS THE ERROR HANDLER");
        assert_ne!(c.category, QueryCategory::IdentifierLookup);
    }

    #[test]
    fn test_classify_numbers() {
        let c = classify_query("404");
        assert_eq!(c.category, QueryCategory::Unknown);
    }

    #[test]
    fn test_classify_hex() {
        let c = classify_query("0xFF");
        assert_eq!(c.category, QueryCategory::IdentifierLookup);
    }

    #[test]
    fn test_classify_mixed_signals() {
        let c = classify_query("not struct");
        assert_eq!(c.category, QueryCategory::Negation);
    }

    #[test]
    fn test_classify_sql_injection() {
        let c = classify_query("'; DROP TABLE--");
        assert_ne!(c.category, QueryCategory::IdentifierLookup);
    }

    #[test]
    fn test_classify_null_bytes() {
        // An embedded NUL is not an identifier character, so no rule matches.
        let c = classify_query("foo\0bar");
        assert_eq!(c.category, QueryCategory::Unknown);
        assert_eq!(c.confidence, Confidence::Low);
    }

    #[test]
    fn test_classify_type_hint_wrong_extraction() {
        let hints = extract_type_hints("error handling");
        assert!(hints.is_none());
    }

    #[test]
    fn test_classify_identifier_common_word() {
        let c = classify_query("error");
        assert_eq!(c.category, QueryCategory::IdentifierLookup);
    }

    #[test]
    fn test_display_labels() {
        assert_eq!(QueryCategory::MultiStep.to_string(), "multi_step");
        assert_eq!(QueryCategory::IdentifierLookup.to_string(), "identifier_lookup");
        assert_eq!(Confidence::Medium.to_string(), "medium");
        assert_eq!(SearchStrategy::DenseDefault.to_string(), "dense");
        assert_eq!(SearchStrategy::DenseWithTypeHints.to_string(), "dense_type_hints");
    }
}