use crate::language::{ChunkType, REGISTRY};
/// Category assigned to a search query by `classify_query`.
///
/// `Display` renders each variant as a lowercase snake_case label
/// (e.g. `identifier_lookup`); `resolve_splade_alpha` upper-cases that label
/// to build its per-category environment variable name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryCategory {
/// The query looks like a code identifier or path (see `is_identifier_query`).
IdentifierLookup,
/// The query matches a structural phrase or keyword (see `is_structural_query`).
Structural,
/// The query contains a behavioral verb or phrase (see `is_behavioral_query`).
Behavioral,
/// The query contains a conceptual noun, or is a short natural-language
/// query (see `is_conceptual_query`).
Conceptual,
/// The query contains a connector from `MULTISTEP_PATTERNS`.
MultiStep,
/// The query contains a word from `NEGATION_TOKENS`.
Negation,
/// `extract_type_hints` found at least one chunk type in the query.
TypeFiltered,
/// The query names languages in a way accepted by `is_cross_language_query`.
CrossLanguage,
/// The query is empty or matched none of the other rules.
Unknown,
}
impl std::fmt::Display for QueryCategory {
    /// Writes the lowercase snake_case label of the category.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::IdentifierLookup => "identifier_lookup",
            Self::Structural => "structural",
            Self::Behavioral => "behavioral",
            Self::Conceptual => "conceptual",
            Self::MultiStep => "multi_step",
            Self::Negation => "negation",
            Self::TypeFiltered => "type_filtered",
            Self::CrossLanguage => "cross_language",
            Self::Unknown => "unknown",
        };
        // write_str, like the literal-only write!, ignores width/fill flags.
        f.write_str(label)
    }
}
/// How strongly `classify_query` trusts the category it assigned.
///
/// `Display` renders the variants as `high`, `medium` and `low`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Confidence {
/// Reported for negation, identifier-lookup and cross-language matches.
High,
/// Reported for type-filtered, structural, behavioral and conceptual matches.
Medium,
/// Reported for multi-step matches and for the `Unknown` fallback.
Low,
}
impl std::fmt::Display for Confidence {
    /// Writes the lowercase label of the confidence level.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::High => "high",
            Self::Medium => "medium",
            Self::Low => "low",
        };
        f.write_str(label)
    }
}
/// Retrieval strategy that `classify_query` pairs with each category.
///
/// NOTE(review): the retrieval behaviour behind each variant is implemented
/// outside this block; the notes below only record which categories select
/// the variant in `classify_query` and how `Display` renders it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchStrategy {
/// Selected for `IdentifierLookup`; displayed as `name_only`.
NameOnly,
/// Selected for `CrossLanguage`, `Conceptual` and `Unknown`; displayed as `dense`.
DenseDefault,
/// Selected for `Structural`; displayed as `dense_type_hints`.
DenseWithTypeHints,
/// Selected for `Negation`, `TypeFiltered`, `Behavioral` and `MultiStep`;
/// displayed as `dense_base`.
DenseBase,
}
impl std::fmt::Display for SearchStrategy {
    /// Writes the lowercase label of the strategy.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::NameOnly => "name_only",
            Self::DenseDefault => "dense",
            Self::DenseWithTypeHints => "dense_type_hints",
            Self::DenseBase => "dense_base",
        };
        f.write_str(label)
    }
}
/// Result of `classify_query`: the detected category plus routing decisions.
#[derive(Debug, Clone)]
pub struct Classification {
/// Category the query was assigned to.
pub category: QueryCategory,
/// How confident the classifier is in `category`.
pub confidence: Confidence,
/// Retrieval strategy chosen for this category.
pub strategy: SearchStrategy,
/// Chunk types found by `extract_type_hints`; `classify_query` sets this
/// to `Some` only for `QueryCategory::TypeFiltered` results.
pub type_hints: Option<Vec<ChunkType>>,
}
/// Lowercase words that mark a query as natural language.
///
/// `is_identifier_query` rejects any query containing one of these as a whole
/// word, and `is_conceptual_query` treats a query of at most three words that
/// contains one as conceptual (unless it is structural).
const NL_INDICATORS: &[&str] = &[
"the",
"a",
"an",
"that",
"which",
"how",
"what",
"where",
"when",
"find",
"get",
"all",
"every",
"each",
"with",
"without",
"for",
"from",
"into",
"this",
"does",
"code",
"function",
"method",
"implement",
"using",
];
/// Lowercase third-person verbs that mark a query as behavioral.
///
/// `is_behavioral_query` compares them against whole whitespace-separated
/// words, so only the exact inflection listed here matches.
const BEHAVIORAL_VERBS: &[&str] = &[
"validates",
"processes",
"handles",
"manages",
"computes",
"parses",
"converts",
"transforms",
"filters",
"sorts",
"checks",
"verifies",
"sends",
"receives",
"reads",
"writes",
"creates",
"deletes",
"updates",
"serializes",
"deserializes",
"encodes",
"decodes",
"authenticates",
"authorizes",
"logs",
"retries",
"caches",
"renders",
];
/// Lowercase nouns that mark a query as conceptual.
///
/// `is_conceptual_query` compares them against whole whitespace-separated
/// words of the query.
const CONCEPTUAL_NOUNS: &[&str] = &[
"pattern",
"architecture",
"design",
"approach",
"strategy",
"algorithm",
"principle",
"abstraction",
"convention",
"idiom",
"paradigm",
"concept",
"technique",
"methodology",
];
/// Lowercase words that mark a query as a negation.
///
/// `classify_query` checks these against whole whitespace-separated words
/// right after the empty-query check, before any other rule.
const NEGATION_TOKENS: &[&str] = &[
"not",
"without",
"except",
"never",
"avoid",
"no",
"don't",
"doesn't",
"shouldn't",
"exclude",
];
/// Lowercase keywords naming code constructs.
///
/// `is_structural_query` accepts a keyword only when it is followed by a
/// space and sits either at the start of the query or directly after a space;
/// a keyword that ends the query does not match.
const STRUCTURAL_KEYWORDS: &[&str] = &[
"struct",
"enum",
"trait",
"impl",
"interface",
"class",
"module",
"namespace",
"protocol",
"type",
];
/// Extra language spellings that `language_names` appends to the registry
/// names when they are not already present.
const LANGUAGE_ALIASES: &[&str] = &["c++", "c#"];
/// Returns the name of every language in `REGISTRY`, followed by each entry
/// of `LANGUAGE_ALIASES` that the registry did not already provide.
fn language_names() -> Vec<&'static str> {
    let mut known: Vec<&'static str> = REGISTRY.all().map(|def| def.name).collect();
    for &alias in LANGUAGE_ALIASES {
        let already_listed = known.iter().any(|name| *name == alias);
        if !already_listed {
            known.push(alias);
        }
    }
    known
}
/// Lowercase phrases that mark a query as structural.
///
/// `is_structural_query` matches these as plain substrings of the query.
const STRUCTURAL_PATTERNS: &[&str] = &[
"functions that",
"methods that",
"types that",
"structs that",
"that return",
"that take",
"that accept",
"with signature",
"implementing",
"extending",
"deriving",
];
/// Lowercase connectors that mark a query as multi-step.
///
/// `classify_query` matches these as plain substrings, so the leading and
/// trailing spaces inside the entries are significant. This is the last rule
/// tried before falling back to `Unknown`.
const MULTISTEP_PATTERNS: &[&str] = &[
"and then", "before ", "after ", " and ", " or ", "first ", "then ", "both ", "between ",
];
/// Resolves the SPLADE alpha value to use for `category`.
///
/// NOTE(review): alpha is presumably the sparse/dense blend weight; confirm
/// against the caller.
///
/// Resolution order:
/// 1. `CQS_SPLADE_ALPHA_<CATEGORY>`, where `<CATEGORY>` is the upper-cased
///    `Display` label of the category (e.g. `CQS_SPLADE_ALPHA_STRUCTURAL`);
/// 2. `CQS_SPLADE_ALPHA`;
/// 3. the built-in per-category default.
///
/// Environment values are clamped to `0.0..=1.0`. A value that does not parse
/// as `f32`, or parses to a non-finite number, is logged and skipped so the
/// next source in the list is consulted.
pub fn resolve_splade_alpha(category: &QueryCategory) -> f32 {
    let cat_key = format!("CQS_SPLADE_ALPHA_{}", category.to_string().to_uppercase());
    if let Some(alpha) = splade_alpha_from_env(&cat_key) {
        return alpha;
    }
    if let Some(alpha) = splade_alpha_from_env("CQS_SPLADE_ALPHA") {
        return alpha;
    }
    // Exhaustive on purpose: adding a category forces a decision here.
    match category {
        QueryCategory::IdentifierLookup => 0.90,
        QueryCategory::Structural => 0.60,
        QueryCategory::Conceptual => 0.85,
        QueryCategory::Behavioral => 0.05,
        QueryCategory::MultiStep
        | QueryCategory::Negation
        | QueryCategory::TypeFiltered
        | QueryCategory::CrossLanguage
        | QueryCategory::Unknown => 1.0,
    }
}

/// Reads one alpha override from the environment variable `var`.
///
/// Returns `None` when the variable is unset or not valid Unicode, and also
/// (after logging a warning) when its value is not a finite `f32`.
fn splade_alpha_from_env(var: &str) -> Option<f32> {
    let raw = std::env::var(var).ok()?;
    // Tolerate stray whitespace around the number (e.g. from shell quoting).
    match raw.trim().parse::<f32>() {
        Ok(alpha) if alpha.is_finite() => Some(alpha.clamp(0.0, 1.0)),
        Ok(_) => {
            tracing::warn!(var = %var, value = %raw, "Non-finite alpha, using default");
            None
        }
        Err(_) => {
            tracing::warn!(var = %var, value = %raw, "Invalid alpha, using default");
            None
        }
    }
}
pub fn classify_query(query: &str) -> Classification {
let query_lower = query.to_lowercase();
let words: Vec<&str> = query_lower.split_whitespace().collect();
if words.is_empty() {
return Classification {
category: QueryCategory::Unknown,
confidence: Confidence::Low,
strategy: SearchStrategy::DenseDefault,
type_hints: None,
};
}
if words.iter().any(|w| NEGATION_TOKENS.contains(w)) {
return Classification {
category: QueryCategory::Negation,
confidence: Confidence::High,
strategy: SearchStrategy::DenseBase,
type_hints: None,
};
}
if is_identifier_query(&query_lower, &words) {
return Classification {
category: QueryCategory::IdentifierLookup,
confidence: Confidence::High,
strategy: SearchStrategy::NameOnly,
type_hints: None,
};
}
if is_cross_language_query(&query_lower, &words) {
return Classification {
category: QueryCategory::CrossLanguage,
confidence: Confidence::High,
strategy: SearchStrategy::DenseDefault,
type_hints: None,
};
}
let type_hints = extract_type_hints(&query_lower);
if type_hints.is_some() {
return Classification {
category: QueryCategory::TypeFiltered,
confidence: Confidence::Medium,
strategy: SearchStrategy::DenseBase,
type_hints,
};
}
if is_structural_query(&query_lower) {
return Classification {
category: QueryCategory::Structural,
confidence: Confidence::Medium,
strategy: SearchStrategy::DenseWithTypeHints,
type_hints: None,
};
}
if is_behavioral_query(&query_lower, &words) {
return Classification {
category: QueryCategory::Behavioral,
confidence: Confidence::Medium,
strategy: SearchStrategy::DenseBase,
type_hints: None,
};
}
if is_conceptual_query(&query_lower, &words) {
return Classification {
category: QueryCategory::Conceptual,
confidence: Confidence::Medium,
strategy: SearchStrategy::DenseDefault,
type_hints: None,
};
}
if MULTISTEP_PATTERNS.iter().any(|p| query_lower.contains(p)) {
return Classification {
category: QueryCategory::MultiStep,
confidence: Confidence::Low,
strategy: SearchStrategy::DenseBase,
type_hints: None,
};
}
Classification {
category: QueryCategory::Unknown,
confidence: Confidence::Low,
strategy: SearchStrategy::DenseDefault,
type_hints: None,
}
}
fn is_identifier_query(_query: &str, words: &[&str]) -> bool {
if words.len() == 1 {
let w = words[0];
if !w.chars().any(|c| c.is_alphabetic()) {
return false;
}
if NL_INDICATORS.contains(&w) {
return false;
}
return w
.chars()
.all(|c| c.is_alphanumeric() || c == '_' || c == ':' || c == '.' || c == '/');
}
if words.len() <= 3 {
let has_nl = words.iter().any(|w| NL_INDICATORS.contains(w));
if has_nl {
return false;
}
let has_identifier_signal = words.iter().any(|w| {
w.contains('_')
|| w.contains("::")
|| w.contains('.')
|| (w.chars().any(|c| c.is_uppercase()) && w.chars().any(|c| c.is_lowercase()))
});
let all_identifier_chars = words.iter().all(|w| {
w.chars()
.all(|c| c.is_alphanumeric() || c == '_' || c == ':' || c == '.')
});
return has_identifier_signal && all_identifier_chars;
}
false
}
/// Decides whether a query asks about more than one language.
///
/// * `query` is the lowercased query text.
/// * `words` are its whitespace-separated words.
///
/// Returns `true` when at least two distinct names from `language_names`
/// appear as whole words, or when exactly one appears together with
/// translation wording: the substrings `equivalent` or `translate`, or the
/// whole word `port` or `convert` followed by another word.
fn is_cross_language_query(query: &str, words: &[&str]) -> bool {
    let names = language_names();
    let lang_count = names.iter().filter(|lang| words.contains(*lang)).count();
    match lang_count {
        0 => false,
        1 => {
            // `port` and `convert` are matched as whole words: a substring
            // test for "port " would also fire on "import ", "export " and
            // "support ".
            let has_porting_verb = words
                .windows(2)
                .any(|pair| matches!(pair[0], "port" | "convert"));
            query.contains("equivalent") || query.contains("translate") || has_porting_verb
        }
        _ => true,
    }
}
/// Decides whether a (lowercased) query asks about code structure.
///
/// Matches when the query contains any `STRUCTURAL_PATTERNS` phrase, or when
/// a `STRUCTURAL_KEYWORDS` entry appears as a space-delimited token that is
/// followed by a space. A keyword that is the final token does not match.
fn is_structural_query(query: &str) -> bool {
    for phrase in STRUCTURAL_PATTERNS {
        if query.contains(phrase) {
            return true;
        }
    }
    // Splitting on ' ' and discarding the final token leaves exactly the
    // tokens that are followed by a space: the first one covers the
    // "starts with `kw `" case, the rest cover "contains ` kw `".
    let mut tokens: Vec<&str> = query.split(' ').collect();
    tokens.pop();
    tokens.iter().any(|token| STRUCTURAL_KEYWORDS.contains(token))
}
/// Decides whether a (lowercased) query describes what code does.
///
/// Matches when any of `words` is a `BEHAVIORAL_VERBS` entry, or when the
/// query contains the phrase `code that` or `function that`.
fn is_behavioral_query(query: &str, words: &[&str]) -> bool {
    let mentions_verb = words.iter().any(|word| BEHAVIORAL_VERBS.contains(word));
    mentions_verb
        || ["code that", "function that"]
            .iter()
            .any(|phrase| query.contains(phrase))
}
/// Decides whether a (lowercased) query is about a concept rather than code.
///
/// Matches when any of `words` is a `CONCEPTUAL_NOUNS` entry. Otherwise a
/// query of at most three words matches when it contains an `NL_INDICATORS`
/// word and `is_structural_query` rejects it.
fn is_conceptual_query(query: &str, words: &[&str]) -> bool {
    for word in words {
        if CONCEPTUAL_NOUNS.contains(word) {
            return true;
        }
    }
    if words.len() > 3 {
        return false;
    }
    let has_natural_language_word = words.iter().any(|word| NL_INDICATORS.contains(word));
    has_natural_language_word && !is_structural_query(query)
}
/// Extracts chunk-type hints from a (lowercased) query.
///
/// Each phrase in the table below is matched as a plain substring of `query`.
/// The result lists every matched `ChunkType` once, in table order. Several
/// phrases overlap (for example `constructor` is contained in
/// `all constructors`), so a type is recorded only the first time it matches.
///
/// Returns `None` when no phrase matches.
pub fn extract_type_hints(query: &str) -> Option<Vec<ChunkType>> {
    let patterns: &[(&str, ChunkType)] = &[
        ("test function", ChunkType::Test),
        ("test method", ChunkType::Test),
        ("all tests", ChunkType::Test),
        ("every test", ChunkType::Test),
        ("all functions", ChunkType::Function),
        ("every function", ChunkType::Function),
        ("all methods", ChunkType::Method),
        ("every method", ChunkType::Method),
        ("all structs", ChunkType::Struct),
        ("every struct", ChunkType::Struct),
        ("all enums", ChunkType::Enum),
        ("every enum", ChunkType::Enum),
        ("all traits", ChunkType::Trait),
        ("every trait", ChunkType::Trait),
        ("all interfaces", ChunkType::Interface),
        ("every interface", ChunkType::Interface),
        ("all classes", ChunkType::Class),
        ("every class", ChunkType::Class),
        ("type alias", ChunkType::TypeAlias),
        ("all type aliases", ChunkType::TypeAlias),
        ("all modules", ChunkType::Module),
        ("every module", ChunkType::Module),
        ("all objects", ChunkType::Object),
        ("every object", ChunkType::Object),
        ("all namespaces", ChunkType::Namespace),
        ("every namespace", ChunkType::Namespace),
        ("all impl blocks", ChunkType::Impl),
        ("implementation block", ChunkType::Impl),
        ("extension method", ChunkType::Extension),
        ("all extensions", ChunkType::Extension),
        ("all constants", ChunkType::Constant),
        ("every constant", ChunkType::Constant),
        ("all variables", ChunkType::Variable),
        ("every variable", ChunkType::Variable),
        ("all properties", ChunkType::Property),
        ("every property", ChunkType::Property),
        ("constructor", ChunkType::Constructor),
        ("all constructors", ChunkType::Constructor),
        ("all delegates", ChunkType::Delegate),
        ("every delegate", ChunkType::Delegate),
        ("all events", ChunkType::Event),
        ("every event", ChunkType::Event),
        ("all macros", ChunkType::Macro),
        ("every macro", ChunkType::Macro),
        ("macro_rules", ChunkType::Macro),
        ("endpoint", ChunkType::Endpoint),
        ("all endpoints", ChunkType::Endpoint),
        ("all services", ChunkType::Service),
        ("every service", ChunkType::Service),
        ("middleware", ChunkType::Middleware),
        ("all middleware", ChunkType::Middleware),
        ("stored procedure", ChunkType::StoredProc),
        ("all stored procedures", ChunkType::StoredProc),
        ("extern function", ChunkType::Extern),
        ("all externs", ChunkType::Extern),
        ("ffi declaration", ChunkType::Extern),
        ("config key", ChunkType::ConfigKey),
        ("all config keys", ChunkType::ConfigKey),
        ("all sections", ChunkType::Section),
        ("every section", ChunkType::Section),
        ("all modifiers", ChunkType::Modifier),
        ("every modifier", ChunkType::Modifier),
    ];

    let mut types: Vec<ChunkType> = Vec::new();
    for (pattern, chunk_type) in patterns {
        // Skip types already recorded so overlapping phrases do not yield
        // duplicate hints.
        if query.contains(pattern) && !types.contains(chunk_type) {
            types.push(*chunk_type);
        }
    }

    if types.is_empty() {
        None
    } else {
        Some(types)
    }
}
// Unit tests for the query classifier: one happy-path test per category,
// direct checks of `extract_type_hints`, and a set of edge-case inputs.
#[cfg(test)]
mod tests {
use super::*;
// --- Identifier lookups: single tokens in snake_case, qualified and CamelCase form ---
#[test]
fn test_classify_identifier_snake_case() {
let c = classify_query("search_filtered");
assert_eq!(c.category, QueryCategory::IdentifierLookup);
assert_eq!(c.confidence, Confidence::High);
assert_eq!(c.strategy, SearchStrategy::NameOnly);
}
#[test]
fn test_classify_identifier_qualified() {
let c = classify_query("HashMap::new");
assert_eq!(c.category, QueryCategory::IdentifierLookup);
assert_eq!(c.confidence, Confidence::High);
}
#[test]
fn test_classify_identifier_camel() {
let c = classify_query("SearchFilter");
assert_eq!(c.category, QueryCategory::IdentifierLookup);
assert_eq!(c.confidence, Confidence::High);
}
// --- One representative query per remaining category, with its strategy ---
#[test]
fn test_classify_behavioral() {
let c = classify_query("validates user input");
assert_eq!(c.category, QueryCategory::Behavioral);
assert_eq!(c.confidence, Confidence::Medium);
assert_eq!(c.strategy, SearchStrategy::DenseBase);
}
#[test]
fn test_classify_negation() {
// "without" is a negation token; negation is checked before every other rule.
let c = classify_query("sort without allocating");
assert_eq!(c.category, QueryCategory::Negation);
assert_eq!(c.confidence, Confidence::High);
assert_eq!(c.strategy, SearchStrategy::DenseBase);
}
#[test]
fn test_classify_conceptual_routes_to_enriched() {
let c = classify_query("dependency injection pattern");
assert_eq!(c.category, QueryCategory::Conceptual);
assert_eq!(c.strategy, SearchStrategy::DenseDefault);
}
#[test]
fn test_classify_structural_stays_on_enriched() {
let c = classify_query("functions that return Result");
assert_eq!(c.category, QueryCategory::Structural);
assert_eq!(c.strategy, SearchStrategy::DenseWithTypeHints);
}
#[test]
fn test_classify_cross_language_stays_on_enriched() {
let c = classify_query("Python equivalent of map in Rust");
assert_eq!(c.category, QueryCategory::CrossLanguage);
assert_eq!(c.strategy, SearchStrategy::DenseDefault);
}
#[test]
fn test_classify_structural() {
let c = classify_query("functions that return Result");
assert_eq!(c.category, QueryCategory::Structural);
assert_eq!(c.confidence, Confidence::Medium);
}
#[test]
fn test_classify_type_filtered() {
// "all test functions" contains the hint phrase "test function".
let c = classify_query("all test functions");
assert_eq!(c.category, QueryCategory::TypeFiltered);
assert_eq!(c.strategy, SearchStrategy::DenseBase);
assert!(c.type_hints.is_some());
assert!(c.type_hints.unwrap().contains(&ChunkType::Test));
}
#[test]
fn test_classify_cross_language() {
let c = classify_query("Python equivalent of map in Rust");
assert_eq!(c.category, QueryCategory::CrossLanguage);
assert_eq!(c.confidence, Confidence::High);
}
#[test]
fn test_classify_conceptual() {
let c = classify_query("dependency injection pattern");
assert_eq!(c.category, QueryCategory::Conceptual);
assert_eq!(c.confidence, Confidence::Medium);
}
#[test]
fn test_classify_multi_step() {
let c = classify_query("find errors and then retry them");
assert_eq!(c.category, QueryCategory::MultiStep);
assert_eq!(c.confidence, Confidence::Low);
assert_eq!(c.strategy, SearchStrategy::DenseBase);
}
#[test]
fn test_classify_unknown() {
let c = classify_query("asdf jkl qwerty");
assert_eq!(c.category, QueryCategory::Unknown);
assert_eq!(c.confidence, Confidence::Low);
}
// --- Direct checks of extract_type_hints ---
#[test]
fn test_extract_type_hints_struct() {
let hints = extract_type_hints("find all structs");
assert!(hints.is_some());
assert!(hints.unwrap().contains(&ChunkType::Struct));
}
#[test]
fn test_extract_type_hints_none() {
let hints = extract_type_hints("handle errors gracefully");
assert!(hints.is_none());
}
// --- Edge cases and unusual inputs ---
#[test]
fn test_classify_empty() {
let c = classify_query("");
assert_eq!(c.category, QueryCategory::Unknown);
assert_eq!(c.confidence, Confidence::Low);
}
#[test]
fn test_classify_single_char() {
// "a" is a natural-language indicator, so it must not count as an identifier.
let c = classify_query("a");
assert_ne!(c.category, QueryCategory::IdentifierLookup);
}
#[test]
fn test_classify_very_long() {
// NOTE(review): the wall-clock bound below may be flaky on a heavily
// loaded machine or in an unoptimised build.
let long = "a ".repeat(5000);
let start = std::time::Instant::now();
let c = classify_query(&long);
let elapsed = start.elapsed();
assert!(elapsed.as_millis() < 100, "Should complete in <100ms");
assert_eq!(c.confidence, Confidence::Low);
}
#[test]
fn test_classify_unicode_identifier() {
// Non-ASCII letters are alphabetic/alphanumeric for the identifier check.
let c = classify_query("日本語_関数");
assert_eq!(c.category, QueryCategory::IdentifierLookup);
}
#[test]
fn test_classify_path_like() {
// '/' and '.' are accepted in single-word identifier queries.
let c = classify_query("src/store/mod.rs");
assert_eq!(c.category, QueryCategory::IdentifierLookup);
}
#[test]
fn test_classify_only_stopwords() {
let c = classify_query("the a an of");
assert_ne!(c.category, QueryCategory::IdentifierLookup);
}
#[test]
fn test_classify_special_chars() {
// Smoke test: only checks that classification does not panic.
let c = classify_query("fn<T: Hash>()");
let _ = c;
}
#[test]
fn test_classify_all_caps() {
let c = classify_query("WHERE IS THE ERROR HANDLER");
assert_ne!(c.category, QueryCategory::IdentifierLookup);
}
#[test]
fn test_classify_numbers() {
// A token without any letter is never an identifier.
let c = classify_query("404");
assert_eq!(c.category, QueryCategory::Unknown);
}
#[test]
fn test_classify_hex() {
// Unlike "404", "0xFF" contains letters and is all alphanumeric.
let c = classify_query("0xFF");
assert_eq!(c.category, QueryCategory::IdentifierLookup);
}
#[test]
fn test_classify_mixed_signals() {
// Negation outranks the structural keyword.
let c = classify_query("not struct");
assert_eq!(c.category, QueryCategory::Negation);
}
#[test]
fn test_classify_sql_injection() {
let c = classify_query("'; DROP TABLE--");
assert_ne!(c.category, QueryCategory::IdentifierLookup);
}
#[test]
fn test_classify_null_bytes() {
// Smoke test: only checks that classification does not panic.
let c = classify_query("foo\0bar");
let _ = c;
}
#[test]
fn test_classify_type_hint_wrong_extraction() {
let hints = extract_type_hints("error handling");
assert!(hints.is_none());
}
#[test]
fn test_classify_identifier_common_word() {
// A single plain word that is not a natural-language indicator counts as an identifier.
let c = classify_query("error");
assert_eq!(c.category, QueryCategory::IdentifierLookup);
}
}