use std::collections::HashSet;
use crate::nl::tokenize_identifier;
use super::config::ScoringConfig;
/// Heuristically decides whether `query` looks like an identifier lookup
/// (e.g. `parseConfig`, `handle_error`) rather than a natural-language search.
///
/// Rules, applied in order:
/// 1. One or two whitespace-separated words (including an empty query) are
///    treated as name-like.
/// 2. If any word, compared case-insensitively, is in the built-in list of
///    natural-language / search-intent words, the query is not name-like.
/// 3. Three or more words that are already entirely lowercase and contain no
///    `_` are treated as natural language.
/// 4. Anything else is name-like.
pub(crate) fn is_name_like_query(query: &str) -> bool {
    // Words whose presence strongly suggests prose rather than an identifier.
    const NL_WORDS: &[&str] = &[
        "the",
        "a",
        "an",
        "is",
        "are",
        "was",
        "were",
        "that",
        "which",
        "how",
        "what",
        "where",
        "when",
        "does",
        "do",
        "can",
        "should",
        "would",
        "could",
        "for",
        "with",
        "from",
        "into",
        "this",
        "these",
        "those",
        "function",
        "method",
        "code",
        "implement",
        "find",
        "search",
    ];

    let word_count = query.split_whitespace().count();
    if word_count <= 2 {
        return true;
    }

    let lowered = query.to_lowercase();
    let mentions_nl_word = lowered
        .split_whitespace()
        .any(|token| NL_WORDS.contains(&token));
    if mentions_nl_word {
        return false;
    }

    // Only reachable with three or more words. An all-lowercase phrase with
    // no snake_case underscore reads like prose, not an identifier.
    let already_lowercase = lowered == query;
    let has_underscore = query.contains('_');
    !already_lowercase || has_underscore
}
/// Scores how well candidate symbol names match a fixed query.
///
/// The query is lowercased and tokenized once in [`NameMatcher::new`], so one
/// matcher can be reused to score many names.
pub(crate) struct NameMatcher {
    /// The whole query, lowercased; used for exact and substring comparisons.
    query_lower: String,
    /// The query split by `tokenize_identifier`; used for the word-overlap
    /// fallback.
    query_words: Vec<String>,
}

impl NameMatcher {
    /// Builds a matcher for `query`.
    pub fn new(query: &str) -> Self {
        Self {
            query_lower: query.to_lowercase(),
            query_words: tokenize_identifier(query),
        }
    }

    /// Returns a relevance score for `name` against the query.
    ///
    /// Checks are case-insensitive and applied in order:
    /// 1. exact match → `name_exact`
    /// 2. `name` contains the query → `name_contains`
    /// 3. the query contains `name` → `name_contained_by`
    /// 4. otherwise, the fraction of query words that equal, contain, or are
    ///    contained in some word of `name`, scaled by `name_max_overlap`.
    ///
    /// An empty query or an empty name scores `0.0`.
    pub fn score(&self, name: &str) -> f32 {
        // `str::contains("")` is always true. Without this guard, an empty query
        // would score `name_contains` against every name, and an empty name
        // would score `name_contained_by` against every query.
        if self.query_lower.is_empty() || name.is_empty() {
            return 0.0;
        }

        let cfg = &ScoringConfig::DEFAULT;
        let name_lower = name.to_lowercase();
        if name_lower == self.query_lower {
            return cfg.name_exact;
        }
        if name_lower.contains(&self.query_lower) {
            return cfg.name_contains;
        }
        if self.query_lower.contains(&name_lower) {
            return cfg.name_contained_by;
        }

        if self.query_words.is_empty() {
            return 0.0;
        }
        let name_words: Vec<String> = tokenize_identifier(name);
        if name_words.is_empty() {
            return 0.0;
        }

        // Fast path for exact word hits; the linear scan below handles partial
        // (substring) hits such as "parse" vs "parser".
        let name_word_set: HashSet<&str> = name_words.iter().map(String::as_str).collect();
        let overlap = self
            .query_words
            .iter()
            .filter(|w| {
                if name_word_set.contains(w.as_str()) {
                    return true;
                }
                name_words.iter().any(|nw| {
                    (nw.len() > w.len() && nw.contains(w.as_str()))
                        || (w.len() > nw.len() && w.contains(nw.as_str()))
                })
            })
            .count() as f32;

        // `query_words` is non-empty here; `max(1)` is kept as a cheap guard
        // against division by zero.
        let total = self.query_words.len().max(1) as f32;
        (overlap / total) * cfg.name_max_overlap
    }
}
#[cfg(test)]
/// Test-only convenience: scores `name` against `query` with a throwaway
/// [`NameMatcher`].
pub(crate) fn name_match_score(query: &str, name: &str) -> f32 {
    let matcher = NameMatcher::new(query);
    matcher.score(name)
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- name_match_score / NameMatcher ----------------------------------

    #[test]
    fn test_name_match_exact() {
        let score = name_match_score("parse", "parse");
        assert_eq!(score, 1.0);
    }

    #[test]
    fn test_name_match_contains() {
        let score = name_match_score("parse", "parseConfig");
        assert_eq!(score, 0.8);
    }

    #[test]
    fn test_name_match_contained() {
        let score = name_match_score("parseConfigFile", "parse");
        assert_eq!(score, 0.6);
    }

    #[test]
    fn test_name_match_partial_overlap() {
        let score = name_match_score("parseConfig", "configParser");
        assert!(
            0.0 < score && score <= 0.5,
            "unexpected overlap score {}",
            score
        );
    }

    #[test]
    fn test_name_match_no_match() {
        let score = name_match_score("foo", "bar");
        assert_eq!(score, 0.0);
    }

    // ---- is_name_like_query ---------------------------------------------

    #[test]
    fn test_name_like_single_token() {
        for query in ["parseConfig", "CircuitBreaker", "handle_error"] {
            assert!(is_name_like_query(query), "expected name-like: {:?}", query);
        }
    }

    #[test]
    fn test_name_like_two_tokens() {
        for query in ["parse config", "error handler"] {
            assert!(is_name_like_query(query), "expected name-like: {:?}", query);
        }
    }

    #[test]
    fn test_nl_query_with_indicators() {
        let nl_queries = [
            "function that handles errors",
            "how does parsing work",
            "find error handling code",
            "code that implements retry logic",
        ];
        for query in nl_queries {
            assert!(
                !is_name_like_query(query),
                "expected natural-language: {:?}",
                query
            );
        }
    }

    #[test]
    fn test_nl_query_all_lowercase_3_plus_words() {
        let query = "error handling retry";
        assert!(!is_name_like_query(query));
    }

    #[test]
    fn test_name_like_snake_case_multi() {
        let query = "handle_error_retry";
        assert!(is_name_like_query(query));
    }
}