use crate::search::ranking::QueryType;
use crate::search::search::SearchQuery;
use once_cell::sync::Lazy;
use regex::Regex;
use std::collections::HashSet;
use unicode_normalization::UnicodeNormalization;
/// Maximum accepted query length in bytes (enforced before and after NFC normalization).
pub const MAX_QUERY_LENGTH: usize = 500;
/// Minimum accepted query length in bytes.
pub const MIN_QUERY_LENGTH: usize = 1;
/// Upper bound on the number of results a caller may request.
pub const MAX_TOP_K: usize = 1000;
/// Lower bound on the number of results a caller may request.
pub const MIN_TOP_K: usize = 1;
/// Hard cap on a caller-supplied token budget for context expansion.
pub const MAX_TOKEN_BUDGET: usize = 10000;
/// Token budget assigned automatically to context-expanding intents.
pub const DEFAULT_TOKEN_BUDGET: usize = 2000;
/// Maximum allowed embedding vector dimension.
// NOTE(review): the embedding-dimension bounds are not referenced in this
// file — presumably consumed elsewhere in the crate; confirm before removal.
pub const MAX_EMBEDDING_DIMENSION: usize = 10000;
/// Minimum allowed embedding vector dimension.
pub const MIN_EMBEDDING_DIMENSION: usize = 1;
/// Matches "how does X work"-style questions, e.g. "show me how auth works",
/// "explain how the cache operates". Case-insensitive, anchored at both ends;
/// the subject is bounded to 400 characters to limit regex backtracking.
static HOW_WORKS_PATTERN: Lazy<Regex> = Lazy::new(|| {
    Regex::new(
        r"(?i)^(?:show|tell|explain|describe)\s+(?:me\s+)?how\s+(?:does\s+)?\S.{0,400}?(?:\s+(?:work|works|working|function|functions|operate|operates))?\s*\.?\s*$"
    ).expect("Failed to compile HOW_WORKS_PATTERN")
});
/// Matches "where is X handled"-style questions, e.g. "where is error
/// handling handled", "where do we handle retries". Case-insensitive,
/// anchored; the subject is bounded to 400 characters to limit backtracking.
static WHERE_HANDLED_PATTERN: Lazy<Regex> = Lazy::new(|| {
    // Fix: the "does ... handle" branch previously used a bare `.`, which
    // matches exactly ONE character between "does" and "handle", so queries
    // like "where does auth handle validation" never matched this intent.
    // `.+?` (lazy, one-or-more) accepts a real subject phrase.
    Regex::new(
        r"(?i)^where\s+(?:is|are|do\s+we\s+handle|does\s+.+?\s+handle)\s+\S.{0,400}?(?:\s+handled)?\s*\.?\s*$"
    ).expect("Failed to compile WHERE_HANDLED_PATTERN")
});
/// Matches performance-analysis questions such as "what are the bottlenecks"
/// or "find slow code". Case-insensitive and fully anchored, so only
/// standalone phrasings (optionally ending with a period) match.
static BOTTLENECKS_PATTERN: Lazy<Regex> = Lazy::new(|| {
    Regex::new(
        r"(?i)^(?:what|where|find)\s+(?:are\s+)?(?:the\s+)?(?:bottlenecks|performance\s+issues|slow\s+code|optimization\s+opportunities)\s*\.?\s*$"
    ).expect("Failed to compile BOTTLENECKS_PATTERN")
});
/// Matches complexity-ranking queries such as "most complex functions" or
/// "least complicated modules". Routed to the same intent as bottleneck
/// queries in `detect_intent` (both are structural analyses).
static COMPLEXITY_PATTERN: Lazy<Regex> = Lazy::new(|| {
    Regex::new(
        r"(?i)^(?:most|least)\s+(?:complex|complicated|difficult|simple)(?:\s+\S.{0,100})?\s*\.?\s*$"
    ).expect("Failed to compile COMPLEXITY_PATTERN")
});
/// English function words plus query scaffolding ("show me", "explain",
/// interrogatives) that carry no search signal and are stripped from
/// extracted terms.
static STOP_WORDS: Lazy<HashSet<&'static str>> = Lazy::new(|| {
    HashSet::from([
        "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
        "from", "as", "is", "was", "are", "were", "been", "be", "have", "has", "had", "do", "does",
        "did", "will", "would", "could", "should", "may", "might", "must", "shall", "can", "need",
        "show", "me", "tell", "explain", "describe", "how", "what", "where", "when", "why",
        "which", "that", "this", "these", "those",
    ])
});
/// High-level intent detected from the phrasing of a query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryIntent {
    /// "How does X work" questions — semantic search with context expansion.
    HowWorks,
    /// "Where is X handled" questions — structural search with context expansion.
    WhereHandled,
    /// Performance/complexity analysis questions ("what are the bottlenecks",
    /// "most complex functions") — structural search.
    Bottlenecks,
    /// Generic natural-language question (contains an interrogative word).
    Semantic,
    /// Anything else — treated as a plain text/keyword search.
    Text,
}
/// The result of parsing and validating a raw query string.
#[derive(Debug, Clone)]
pub struct ParsedQuery {
    /// The sanitized (NFC-normalized, trimmed) query text.
    pub original: String,
    /// Lowercased, stop-word-filtered, deduplicated search terms.
    pub terms: Vec<String>,
    /// Detected phrasing intent.
    pub intent: QueryIntent,
    /// Search backend classification derived from the intent.
    pub query_type: QueryType,
    /// Whether surrounding code context should be expanded in results.
    pub expand_context: bool,
    /// Number of results to return (validated against MIN_TOP_K/MAX_TOP_K).
    pub top_k: usize,
    /// Token budget for context expansion; `Some` only when
    /// `expand_context` is true.
    pub token_budget: Option<usize>,
}
impl ParsedQuery {
    /// Checks internal invariants after construction: `top_k` must lie in
    /// `[MIN_TOP_K, MAX_TOP_K]` and any token budget must not exceed
    /// `MAX_TOKEN_BUDGET`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidTopK`] or [`Error::TokenBudgetTooLarge`] on
    /// violation.
    fn validate(&self) -> Result<(), Error> {
        // Same range-contains form as `QueryParser::validate_top_k`, kept
        // consistent so the two checks cannot drift apart.
        if !(MIN_TOP_K..=MAX_TOP_K).contains(&self.top_k) {
            return Err(Error::InvalidTopK {
                provided: self.top_k,
                min: MIN_TOP_K,
                max: MAX_TOP_K,
            });
        }
        if let Some(budget) = self.token_budget {
            if budget > MAX_TOKEN_BUDGET {
                return Err(Error::TokenBudgetTooLarge {
                    provided: budget,
                    max: MAX_TOKEN_BUDGET,
                });
            }
        }
        Ok(())
    }
}
/// Stateless parser that turns raw user queries into validated
/// [`ParsedQuery`] values. All patterns and stop-word tables are shared
/// statics, so the struct itself carries no fields.
pub struct QueryParser {
}
impl QueryParser {
    /// Creates a new query parser.
    ///
    /// Currently infallible; the `Result` return keeps the constructor
    /// forward-compatible with fallible initialization without breaking
    /// callers.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Result<Self, Error> {
        Ok(Self {})
    }

    /// Parses a raw query into a [`ParsedQuery`].
    ///
    /// Pipeline: sanitize and NFC-normalize the input, validate
    /// `default_top_k`, detect intent, extract meaningful terms, classify the
    /// query, and decide whether context expansion (with a default token
    /// budget) applies.
    ///
    /// # Errors
    ///
    /// Returns an error when the query is empty, too long, contains invalid
    /// characters, yields no meaningful terms, or `default_top_k` is out of
    /// range.
    pub fn parse(&self, query: &str, default_top_k: usize) -> Result<ParsedQuery, Error> {
        let query = self.validate_and_sanitize_query(query)?;
        let top_k = self.validate_top_k(default_top_k)?;
        let intent = self.detect_intent(&query);
        let terms = self.extract_terms(&query, &intent)?;
        if terms.is_empty() {
            return Err(Error::NoMeaningfulTerms {
                query: self.truncate_for_error(&query),
                suggestion: "Try using more specific terms or complete sentences",
            });
        }
        let query_type = self.classify_query(&intent);
        // "How does X work" / "where is X handled" questions benefit from
        // surrounding code context, so they get a default token budget.
        let expand_context = matches!(intent, QueryIntent::HowWorks | QueryIntent::WhereHandled);
        let token_budget = if expand_context {
            Some(DEFAULT_TOKEN_BUDGET)
        } else {
            None
        };
        let parsed = ParsedQuery {
            original: query,
            terms,
            intent,
            query_type,
            expand_context,
            top_k,
            token_budget,
        };
        parsed.validate()?;
        Ok(parsed)
    }

    /// Rejects null bytes, over-long input, and non-whitespace control
    /// characters, then NFC-normalizes and trims the query.
    ///
    /// Lengths are measured in BYTES (the error message's "characters" is
    /// approximate for multi-byte input); the limit is re-checked after
    /// normalization because NFC can change the byte length in rare cases.
    fn validate_and_sanitize_query(&self, query: &str) -> Result<String, Error> {
        if query.contains('\0') {
            return Err(Error::InvalidCharacters {
                reason: "Query contains null bytes".to_string(),
            });
        }
        if query.len() > MAX_QUERY_LENGTH {
            return Err(Error::QueryTooLong {
                provided: query.len(),
                max: MAX_QUERY_LENGTH,
                actual_prefix: self.truncate_for_error(query),
            });
        }
        for ch in query.chars() {
            // Tabs/newlines are control characters but also whitespace, so
            // they are allowed; everything else control-class is rejected.
            if ch.is_control() && !ch.is_whitespace() {
                return Err(Error::InvalidCharacters {
                    reason: format!("Query contains control character: U+{:04X}", ch as u32),
                });
            }
        }
        let normalized = query.nfc().collect::<String>();
        let trimmed = normalized.trim();
        if trimmed.is_empty() {
            return Err(Error::EmptyQuery);
        }
        if trimmed.len() > MAX_QUERY_LENGTH {
            return Err(Error::QueryTooLong {
                provided: trimmed.len(),
                max: MAX_QUERY_LENGTH,
                actual_prefix: self.truncate_for_error(trimmed),
            });
        }
        Ok(trimmed.to_string())
    }

    /// Validates that `top_k` lies in `[MIN_TOP_K, MAX_TOP_K]`.
    fn validate_top_k(&self, top_k: usize) -> Result<usize, Error> {
        if !(MIN_TOP_K..=MAX_TOP_K).contains(&top_k) {
            return Err(Error::InvalidTopK {
                provided: top_k,
                min: MIN_TOP_K,
                max: MAX_TOP_K,
            });
        }
        Ok(top_k)
    }

    /// Detects the high-level intent of a sanitized query.
    ///
    /// The anchored regex patterns are tried first (most specific wins); if
    /// none match, the presence of an interrogative word selects semantic
    /// search, and anything else falls back to plain text search.
    fn detect_intent(&self, query: &str) -> QueryIntent {
        if HOW_WORKS_PATTERN.is_match(query) {
            return QueryIntent::HowWorks;
        }
        if WHERE_HANDLED_PATTERN.is_match(query) {
            return QueryIntent::WhereHandled;
        }
        // Both "bottleneck" and "complexity" phrasings map to the same
        // structural-analysis intent.
        if BOTTLENECKS_PATTERN.is_match(query) || COMPLEXITY_PATTERN.is_match(query) {
            return QueryIntent::Bottlenecks;
        }
        // Fix: match interrogatives as whole (punctuation-trimmed) words
        // rather than substrings, so e.g. "showcase"/"slideshow" no longer
        // spuriously match "how".
        let query_lower = query.to_lowercase();
        let has_question_word = query_lower
            .split_whitespace()
            .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
            .any(|w| matches!(w, "how" | "what" | "why" | "where" | "when" | "which"));
        if has_question_word {
            QueryIntent::Semantic
        } else {
            QueryIntent::Text
        }
    }

    /// Extracts search terms from the query according to the detected intent.
    ///
    /// For pattern-matched intents the interrogative scaffolding ("show me
    /// how", "where is", trailing "works"/"handled", ...) is stripped before
    /// tokenizing; otherwise the whole query is tokenized. Stop words are
    /// removed and duplicates dropped while preserving first-seen order, so
    /// the term list (and the search string rebuilt from it) is
    /// deterministic.
    fn extract_terms(&self, query: &str, intent: &QueryIntent) -> Result<Vec<String>, Error> {
        let mut terms = Vec::new();
        match intent {
            QueryIntent::HowWorks => {
                if let Some(captures) = HOW_WORKS_PATTERN.captures(query) {
                    let full_match = captures.get(0).map(|m| m.as_str()).unwrap_or("");
                    // NOTE: this replace chain is phrase-based and order
                    // sensitive; it removes the question scaffolding and
                    // leaves the subject.
                    let subject = full_match
                        .to_lowercase()
                        .replace("show me how ", "")
                        .replace("show how ", "")
                        .replace("tell me how ", "")
                        .replace("tell how ", "")
                        .replace("explain how ", "")
                        .replace("describe how ", "")
                        .replace(" how does ", "")
                        .replace(" how ", "")
                        .replace(" works", "")
                        .replace(" work", "")
                        .replace(" functions", "")
                        .replace(" function", "")
                        .replace(" operates", "")
                        .replace(" operate", "");
                    let subject = subject.trim();
                    if !subject.is_empty() {
                        terms.extend(self.tokenize(subject));
                    }
                }
                // Fall back to the whole query if subject extraction failed.
                if terms.is_empty() {
                    terms.extend(self.tokenize(query));
                }
            }
            QueryIntent::WhereHandled => {
                if let Some(captures) = WHERE_HANDLED_PATTERN.captures(query) {
                    let full_match = captures.get(0).map(|m| m.as_str()).unwrap_or("");
                    let subject = full_match
                        .to_lowercase()
                        .replace("where is ", "")
                        .replace("where are ", "")
                        .replace("where do we handle ", "")
                        .replace("where does ", "")
                        .replace(" handled", "");
                    let subject = subject.trim();
                    if !subject.is_empty() {
                        terms.extend(self.tokenize(subject));
                    }
                }
                if terms.is_empty() {
                    terms.extend(self.tokenize(query));
                }
            }
            QueryIntent::Bottlenecks | QueryIntent::Semantic | QueryIntent::Text => {
                terms.extend(self.tokenize(query));
            }
        }
        // Fix: dedup with an order-preserving seen-set instead of collecting
        // into a HashSet, which randomized term order between runs (and made
        // the joined search string nondeterministic).
        let mut seen = HashSet::new();
        let filtered: Vec<String> = terms
            .into_iter()
            .filter(|t| !STOP_WORDS.contains(t.as_str()))
            .filter(|t| seen.insert(t.clone()))
            .collect();
        Ok(filtered)
    }

    /// Splits on whitespace, lowercases, strips surrounding punctuation, and
    /// drops tokens shorter than three characters.
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.split_whitespace()
            .map(|s| s.to_lowercase())
            .map(|s| {
                // Fix: trim punctuation from BOTH ends (previously only the
                // trailing end, so "(auth)" became "(auth").
                s.trim_matches(|c: char| !c.is_alphanumeric()).to_string()
            })
            .filter(|s| s.len() >= 3)
            .collect()
    }

    /// Maps an intent to the search backend's query type.
    fn classify_query(&self, intent: &QueryIntent) -> QueryType {
        match intent {
            QueryIntent::HowWorks | QueryIntent::Semantic => QueryType::Semantic,
            QueryIntent::WhereHandled | QueryIntent::Bottlenecks => QueryType::Structural,
            QueryIntent::Text => QueryType::Text,
        }
    }

    /// Builds the backend [`SearchQuery`] from a parsed query, preferring the
    /// extracted terms over the raw text when terms are available.
    pub fn build_search_query(&self, parsed: &ParsedQuery) -> SearchQuery {
        let query_text = if parsed.terms.is_empty() {
            parsed.original.clone()
        } else {
            parsed.terms.join(" ")
        };
        SearchQuery {
            query: query_text,
            top_k: parsed.top_k,
            token_budget: parsed.token_budget,
            semantic: matches!(parsed.query_type, QueryType::Semantic),
            expand_context: parsed.expand_context,
            query_embedding: None,
            threshold: None,
            query_type: Some(parsed.query_type),
        }
    }

    /// Truncates a query to at most 100 bytes for inclusion in error
    /// messages, appending "..." when truncated.
    fn truncate_for_error(&self, query: &str) -> String {
        const MAX_PREFIX: usize = 100;
        if query.len() <= MAX_PREFIX {
            return query.to_string();
        }
        // Fix: back the cut point up to a UTF-8 character boundary; slicing
        // at a fixed byte offset (`&query[..97]`) panicked when byte 97 fell
        // inside a multi-byte character.
        let mut end = MAX_PREFIX - 3;
        while !query.is_char_boundary(end) {
            end -= 1;
        }
        format!("{}...", &query[..end])
    }
}
impl Default for QueryParser {
    /// Delegates to [`QueryParser::new`]; `new` is currently infallible, so
    /// the `expect` here can never fire.
    fn default() -> Self {
        Self::new().expect("QueryParser::new should never fail")
    }
}
// `QueryParser` has no fields, so the compiler auto-implements `Send` and
// `Sync` for it. The former `unsafe impl Send/Sync` blocks were unnecessary
// and risky: they would have silently asserted thread-safety even if a
// non-Send/non-Sync field were added later. Removed; the auto traits keep
// the public interface (and `test_parser_is_send_sync`) intact.
/// Errors produced while validating and parsing a search query.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The query was empty or contained only whitespace.
    #[error("Query cannot be empty")]
    EmptyQuery,
    /// The query exceeded `MAX_QUERY_LENGTH` (length measured in bytes).
    #[error("Query too long: {provided} characters (max: {max}). Query: '{actual_prefix}'")]
    QueryTooLong {
        provided: usize,
        max: usize,
        actual_prefix: String,
    },
    /// The query contained null bytes or non-whitespace control characters.
    #[error("Query contains invalid characters: {reason}")]
    InvalidCharacters {
        reason: String,
    },
    /// `top_k` was outside `[MIN_TOP_K, MAX_TOP_K]`.
    #[error("Invalid top_k value: {provided} (must be between {min} and {max})")]
    InvalidTopK {
        provided: usize,
        min: usize,
        max: usize,
    },
    /// A caller-supplied token budget exceeded `MAX_TOKEN_BUDGET`.
    #[error("Token budget too large: {provided} (max: {max})")]
    TokenBudgetTooLarge {
        provided: usize,
        max: usize,
    },
    /// After stop-word filtering no search terms remained.
    #[error("Query contains no meaningful terms: '{query}'. {suggestion}")]
    NoMeaningfulTerms {
        query: String,
        suggestion: &'static str,
    },
    // NOTE(review): the two variants below are never constructed in this
    // file — presumably reserved for callers elsewhere in the crate; confirm
    // before removing.
    #[error("Invalid regex pattern: {0}")]
    InvalidPattern(String),
    #[error("Query parsing failed: {0}")]
    ParseFailed(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- Construction ------------------------------------------------------

    #[test]
    fn test_query_parser_creation() {
        let parser = QueryParser::new();
        assert!(parser.is_ok());
    }

    #[test]
    fn test_default_parser() {
        let parser = QueryParser::default();
        let parsed = parser.parse("test query", 10);
        assert!(parsed.is_ok());
    }

    // --- Empty / whitespace input ------------------------------------------

    #[test]
    fn test_parse_empty_query() {
        let parser = QueryParser::new().unwrap();
        let result = parser.parse("", 10);
        assert!(matches!(result, Err(Error::EmptyQuery)));
    }

    #[test]
    fn test_parse_whitespace_only_query() {
        let parser = QueryParser::new().unwrap();
        let result = parser.parse(" \n\t ", 10);
        assert!(matches!(result, Err(Error::EmptyQuery)));
    }

    // --- Intent detection and classification -------------------------------

    #[test]
    fn test_parse_how_works_query() {
        let parser = QueryParser::new().unwrap();
        let parsed = parser
            .parse("show me how authentication works", 10)
            .unwrap();
        assert_eq!(parsed.intent, QueryIntent::HowWorks);
        assert_eq!(parsed.query_type, QueryType::Semantic);
        assert!(parsed.expand_context);
        assert!(parsed.terms.contains(&"authentication".to_string()));
    }

    #[test]
    fn test_parse_where_handled_query() {
        let parser = QueryParser::new().unwrap();
        let parsed = parser.parse("where is error handling handled", 10).unwrap();
        assert_eq!(parsed.intent, QueryIntent::WhereHandled);
        assert_eq!(parsed.query_type, QueryType::Structural);
        assert!(parsed.expand_context);
        assert!(parsed.terms.contains(&"error".to_string()));
        assert!(parsed.terms.contains(&"handling".to_string()));
    }

    #[test]
    fn test_parse_bottlenecks_query() {
        let parser = QueryParser::new().unwrap();
        let parsed = parser.parse("what are the bottlenecks", 10).unwrap();
        assert_eq!(parsed.intent, QueryIntent::Bottlenecks);
        assert_eq!(parsed.query_type, QueryType::Structural);
        // Bottleneck queries do not expand context (no token budget).
        assert!(!parsed.expand_context);
    }

    #[test]
    fn test_parse_semantic_query() {
        let parser = QueryParser::new().unwrap();
        let parsed = parser.parse("how do I implement caching", 10).unwrap();
        assert_eq!(parsed.intent, QueryIntent::Semantic);
        assert_eq!(parsed.query_type, QueryType::Semantic);
    }

    #[test]
    fn test_parse_text_query() {
        let parser = QueryParser::new().unwrap();
        let parsed = parser.parse("function_name", 10).unwrap();
        assert_eq!(parsed.intent, QueryIntent::Text);
        assert_eq!(parsed.query_type, QueryType::Text);
        assert!(!parsed.expand_context);
    }

    #[test]
    fn test_build_search_query() {
        let parser = QueryParser::new().unwrap();
        let parsed = parser.parse("show me how parsing works", 10).unwrap();
        let search_query = parser.build_search_query(&parsed);
        assert_eq!(search_query.top_k, 10);
        assert!(search_query.semantic);
        assert!(search_query.expand_context);
        assert!(search_query.token_budget.is_some());
    }

    // --- Tokenization and term filtering -----------------------------------

    #[test]
    fn test_tokenize() {
        let parser = QueryParser::new().unwrap();
        let tokens = parser.tokenize("Hello World Test");
        assert_eq!(tokens.len(), 3);
        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
    }

    #[test]
    fn test_tokenize_filters_short_words() {
        let parser = QueryParser::new().unwrap();
        // Only "the" survives the length-3 filter; stop-word removal happens
        // later in extract_terms, not in tokenize.
        let tokens = parser.tokenize("a an the of in");
        assert_eq!(tokens.len(), 1);
        assert_eq!(tokens[0], "the");
    }

    #[test]
    fn test_stop_words_filtering() {
        let parser = QueryParser::new().unwrap();
        let parsed = parser
            .parse("show me how the authentication system works", 10)
            .unwrap();
        assert!(!parsed.terms.contains(&"show".to_string()));
        assert!(!parsed.terms.contains(&"me".to_string()));
        assert!(!parsed.terms.contains(&"the".to_string()));
        assert!(parsed.terms.contains(&"authentication".to_string()));
        assert!(parsed.terms.contains(&"system".to_string()));
    }

    #[test]
    fn test_complexity_query() {
        let parser = QueryParser::new().unwrap();
        let parsed = parser.parse("most complex functions", 10).unwrap();
        // Complexity queries share the Bottlenecks intent.
        assert_eq!(parsed.intent, QueryIntent::Bottlenecks);
        assert_eq!(parsed.query_type, QueryType::Structural);
    }

    #[test]
    fn test_query_with_question_words() {
        let parser = QueryParser::new().unwrap();
        let parsed = parser.parse("what does this function do", 10).unwrap();
        assert_eq!(parsed.intent, QueryIntent::Semantic);
    }

    // --- Length and character validation -----------------------------------

    #[test]
    fn test_query_too_long() {
        let parser = QueryParser::new().unwrap();
        let long_query = "a".repeat(MAX_QUERY_LENGTH + 1);
        let result = parser.parse(&long_query, 10);
        assert!(matches!(result, Err(Error::QueryTooLong { .. })));
    }

    #[test]
    fn test_query_exactly_max_length() {
        let parser = QueryParser::new().unwrap();
        let query = "a".repeat(MAX_QUERY_LENGTH);
        let result = parser.parse(&query, 10);
        assert!(result.is_ok());
    }

    #[test]
    fn test_invalid_top_k_too_small() {
        let parser = QueryParser::new().unwrap();
        let result = parser.parse("test query", 0);
        assert!(matches!(result, Err(Error::InvalidTopK { .. })));
    }

    #[test]
    fn test_invalid_top_k_too_large() {
        let parser = QueryParser::new().unwrap();
        let result = parser.parse("test query", MAX_TOP_K + 1);
        assert!(matches!(result, Err(Error::InvalidTopK { .. })));
    }

    #[test]
    fn test_valid_top_k_boundaries() {
        let parser = QueryParser::new().unwrap();
        assert!(parser.parse("test", MIN_TOP_K).is_ok());
        assert!(parser.parse("test", MAX_TOP_K).is_ok());
    }

    #[test]
    fn test_query_with_null_bytes() {
        let parser = QueryParser::new().unwrap();
        let result = parser.parse("test\x00query", 10);
        assert!(matches!(result, Err(Error::InvalidCharacters { .. })));
    }

    #[test]
    fn test_query_with_control_characters() {
        let parser = QueryParser::new().unwrap();
        let result = parser.parse("test\x01query", 10);
        assert!(matches!(result, Err(Error::InvalidCharacters { .. })));
    }

    #[test]
    fn test_unicode_normalization() {
        let parser = QueryParser::new().unwrap();
        // "café" precomposed vs. "cafe" + combining acute accent: NFC
        // normalization must make both yield identical terms.
        let query1 = parser.parse("café", 10);
        let query2 = parser.parse("cafe\u{301}", 10);
        assert!(query1.is_ok());
        assert!(query2.is_ok());
        assert_eq!(query1.unwrap().terms, query2.unwrap().terms);
    }

    #[test]
    fn test_no_meaningful_terms() {
        let parser = QueryParser::new().unwrap();
        let result = parser.parse("the a an of in", 10);
        assert!(matches!(result, Err(Error::NoMeaningfulTerms { .. })));
    }

    // --- Static allocation and trait guarantees -----------------------------

    #[test]
    fn test_stop_words_set_is_efficient() {
        // Two derefs of the Lazy must observe the same allocation.
        let addr1 = &*STOP_WORDS as *const _ as usize;
        let addr2 = &*STOP_WORDS as *const _ as usize;
        assert_eq!(addr1, addr2, "STOP_WORDS should be statically allocated");
    }

    #[test]
    fn test_regex_patterns_are_static() {
        let pattern1 = &*HOW_WORKS_PATTERN as *const _ as usize;
        let pattern2 = &*HOW_WORKS_PATTERN as *const _ as usize;
        assert_eq!(
            pattern1, pattern2,
            "Regex patterns should be statically allocated"
        );
    }

    #[test]
    fn test_parser_is_send_sync() {
        // Compile-time check that QueryParser is Send + Sync.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<QueryParser>();
    }

    #[test]
    fn test_parsed_query_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<ParsedQuery>();
    }

    // --- Determinism across repeated parses ---------------------------------

    #[test]
    fn test_parse_performance() {
        let parser = QueryParser::new().unwrap();
        let queries = vec![
            "show me how authentication works",
            "where is error handling handled",
            "what are the bottlenecks",
            "most complex functions",
            "how do I implement caching",
            "function_name",
        ];
        for query in &queries {
            let result1 = parser.parse(query, 10);
            let result2 = parser.parse(query, 10);
            let result3 = parser.parse(query, 10);
            assert!(result1.is_ok(), "Failed to parse: {}", query);
            let intent1 = result1.unwrap().intent;
            let intent2 = result2.unwrap().intent;
            let intent3 = result3.unwrap().intent;
            assert_eq!(intent1, intent2);
            assert_eq!(intent2, intent3);
        }
    }

    #[test]
    fn test_extracted_terms_are_normalized() {
        let parser = QueryParser::new().unwrap();
        let parsed = parser
            .parse("Show Me How AUTHENTICATION Works", 10)
            .unwrap();
        for term in &parsed.terms {
            assert_eq!(
                term.to_lowercase(),
                *term,
                "Term '{}' should be lowercase",
                term
            );
        }
    }

    #[test]
    fn test_terms_no_duplicates() {
        let parser = QueryParser::new().unwrap();
        let parsed = parser.parse("test test test testing", 10).unwrap();
        let unique_terms: std::collections::HashSet<_> = parsed.terms.iter().collect();
        assert_eq!(
            parsed.terms.len(),
            unique_terms.len(),
            "Terms should not contain duplicates"
        );
    }
}