use std::{ffi::OsStr, path::Path};
use crate::{
models::{self},
resolver::{
constant::{self, DEFAULT_SCORE},
utils, weight,
},
};
/// Builds the `frizbee` matcher configuration for `query`.
///
/// * The typo allowance grows with the query: one typo per three bytes of
///   `query` (`str::len` counts bytes, not characters), capped at four.
/// * Matching is "smart-case": the capitalisation and matching-case bonuses
///   are only enabled when the query contains at least one uppercase
///   character; otherwise both are zero.
/// * Sorting by frizbee is disabled (`sort: false`).
pub fn get_fuzzy_config(query: &str) -> frizbee::Config {
    // The budget is at most 4, so the conversion to `u16` cannot actually fail;
    // the fallback of 0 only exists to avoid a panic path.
    let typo_budget = (query.len() / 3).min(4);
    let max_typos = u16::try_from(typo_budget).unwrap_or_default();

    let (capitalization_bonus, matching_case_bonus) = if query.chars().any(char::is_uppercase) {
        (
            weight::CASE_SENSITIVE_MATCHING_CAPITALISATION_BONUS,
            weight::CASE_SENSITIVE_MATCHING_CASE_BONUS,
        )
    } else {
        (0, 0)
    };

    let scoring = frizbee::Scoring {
        capitalization_bonus,
        matching_case_bonus,
        ..Default::default()
    };

    frizbee::Config {
        max_typos: Some(max_typos),
        sort: false,
        scoring,
    }
}
/// Fuzzy-matches `query` against a symbol's name.
///
/// Two candidates are handed to `frizbee::match_list`, in this order: the
/// symbol name exactly as written, followed by its lowercased form. The
/// matches returned by frizbee are passed through untouched.
pub fn fuzzy_match(
    query: &str,
    symbol: &models::resolved::ResolvedSymbol,
    config: &frizbee::Config,
) -> Vec<frizbee::Match> {
    let lowercased_name = symbol.name.to_lowercase();
    let candidates = [symbol.name.as_str(), lowercased_name.as_str()];

    frizbee::match_list(query, &candidates, config)
}
/// Computes the ranking score of `symbol` for `query`.
///
/// The score starts at `DEFAULT_SCORE` and is adjusted by the following
/// components, combined with saturating addition:
///
/// 1. an entrypoint-file adjustment when the symbol's file name is recognised
///    by `utils::is_entrypoint_file`;
/// 2. the sum of `weight::calculate_fuzzy_match_bonus` over `fuzzy_matches`;
/// 3. the clear-intent bonus from [`calculate_clear_intent_bonus`];
/// 4. a bonus or penalty depending on the symbol kind;
/// 5. a test-harness adjustment when `utils::is_part_of_test_harness` holds;
/// 6. a distance adjustment relative to `current_file` (a dedicated value when
///    the symbol lives in that very file, nothing when no file is given);
/// 7. an adjustment for autogenerated code.
///
/// The `*_PENALTY` weights are added like the bonuses, so they are presumably
/// negative values — confirm against the `weight` module.
pub fn calculate_score<'a, 'b>(
    query: &str,
    symbol: &models::resolved::ResolvedSymbol,
    fuzzy_matches: impl Iterator<Item = &'a frizbee::Match>,
    current_file: Option<&'b Path>,
) -> i64 {
    // Small helper for the recurring "apply this weight only if the condition holds" pattern.
    let weight_if = |applies: bool, amount: i64| if applies { amount } else { 0 };

    // A path without a final component, or with a non-UTF-8 one, never counts as an entrypoint.
    let lives_in_entrypoint_file = symbol
        .path
        .file_name()
        .and_then(OsStr::to_str)
        .is_some_and(|filename| utils::is_entrypoint_file(filename));
    let entrypoint_file_penalty = weight_if(
        lives_in_entrypoint_file,
        weight::ENTRYPOINT_FILE_SCORE_PENALTY,
    );

    let fuzzy_match_bonus: i64 = fuzzy_matches
        .map(weight::calculate_fuzzy_match_bonus)
        .sum();

    let clear_intent_bonus: i64 = calculate_clear_intent_bonus(query, symbol);

    let symbol_kind_bonus = match symbol.kind {
        // Kinds that receive the "common" bonus.
        models::parsed::SymbolKind::Function
        | models::parsed::SymbolKind::Getter
        | models::parsed::SymbolKind::Setter
        | models::parsed::SymbolKind::Method
        | models::parsed::SymbolKind::Struct
        | models::parsed::SymbolKind::Trait
        | models::parsed::SymbolKind::Class
        | models::parsed::SymbolKind::Constant
        | models::parsed::SymbolKind::Enum
        | models::parsed::SymbolKind::EnumMember
        | models::parsed::SymbolKind::Interface => weight::COMMON_SYMBOL_KINDS_SCORE_BONUS,
        // Kinds that receive the smaller "infrequent" bonus.
        models::parsed::SymbolKind::Variable
        | models::parsed::SymbolKind::Type
        | models::parsed::SymbolKind::TypeAlias => weight::INFREQUENT_SYMBOL_KINDS_SCORE_BONUS,
        // Kinds that are penalised.
        models::parsed::SymbolKind::Package
        | models::parsed::SymbolKind::Module
        | models::parsed::SymbolKind::SelfParameter => weight::UNCOMMON_SYMBOL_KINDS_SCORE_PENALTY,
        // Every other kind is neutral.
        _ => 0,
    };

    let test_harness_penalty = weight_if(
        utils::is_part_of_test_harness(symbol.path.as_path()),
        weight::TEST_HARNESS_SCORE_PENALTY,
    );

    let distance_penalty = match current_file {
        Some(current_file) if current_file == symbol.path => weight::SAME_FILE_SCORE_PENALTY,
        Some(current_file) => weight::calculate_distance_score_penalty(utils::get_path_distance(
            current_file,
            symbol.path.as_path(),
        )),
        None => 0,
    };

    let autogenerated_code_penalty = weight_if(
        utils::is_autogenerated_code(&symbol.path),
        weight::AUTOGENERATED_CODE_SCORE_PENALTY,
    );

    // Saturate step by step, in a fixed order, so the result is well defined
    // even if an individual adjustment is extreme.
    [
        entrypoint_file_penalty,
        fuzzy_match_bonus,
        clear_intent_bonus,
        symbol_kind_bonus,
        test_harness_penalty,
        distance_penalty,
        autogenerated_code_penalty,
    ]
    .into_iter()
    .fold(DEFAULT_SCORE, |total, adjustment| {
        total.saturating_add(adjustment)
    })
}
/// Rewards symbols whose kind agrees with the naming convention of `query`.
///
/// Queries shorter than `constant::MIN_CLEAR_INTENT_QUERY_LENGTH` (measured in
/// bytes) never earn the bonus. For longer queries the casing is classified as:
///
/// * *screaming case* — contains uppercase and no lowercase characters;
/// * *mixed case* — contains both uppercase and lowercase, and no underscore;
/// * *snake case* — contains an underscore and no uppercase characters.
///
/// The bonus `weight::CLEAR_QUERY_INTENT_SYMBOL_KINDS_SCORE_BONUS` is returned when
///
/// * a constant/static-like kind meets a screaming-case query,
/// * a type-like kind (or `Variable`) meets a mixed-case query, or
/// * a function-like kind meets a snake-case or mixed-case query.
///
/// Every other combination yields `0`.
pub fn calculate_clear_intent_bonus(query: &str, symbol: &models::resolved::ResolvedSymbol) -> i64 {
    // Every bonus requires a sufficiently long query, so bail out before classifying it.
    if query.len() < constant::MIN_CLEAR_INTENT_QUERY_LENGTH as usize {
        return 0;
    }

    let mut saw_upper = false;
    let mut saw_lower = false;
    let mut saw_underscore = false;
    for ch in query.chars() {
        saw_upper |= ch.is_uppercase();
        saw_lower |= ch.is_lowercase();
        saw_underscore |= ch == '_';
    }

    let mixed_case = saw_upper && saw_lower && !saw_underscore;
    let snake_case = saw_underscore && !saw_upper;
    let screaming_case = saw_upper && !saw_lower;

    let intent_matches_kind = match symbol.kind {
        models::parsed::SymbolKind::Constant
        | models::parsed::SymbolKind::StaticField
        | models::parsed::SymbolKind::StaticVariable
        | models::parsed::SymbolKind::StaticDataMember => screaming_case,
        models::parsed::SymbolKind::Struct
        | models::parsed::SymbolKind::Type
        | models::parsed::SymbolKind::TypeAlias
        | models::parsed::SymbolKind::Class
        | models::parsed::SymbolKind::Enum
        | models::parsed::SymbolKind::EnumMember
        | models::parsed::SymbolKind::Interface
        | models::parsed::SymbolKind::Trait
        | models::parsed::SymbolKind::Protocol
        | models::parsed::SymbolKind::Union
        | models::parsed::SymbolKind::Variable => mixed_case,
        models::parsed::SymbolKind::Function
        | models::parsed::SymbolKind::Method
        | models::parsed::SymbolKind::Predicate
        | models::parsed::SymbolKind::TraitMethod
        | models::parsed::SymbolKind::ProtocolMethod
        | models::parsed::SymbolKind::AbstractMethod
        | models::parsed::SymbolKind::Getter => snake_case || mixed_case,
        _ => false,
    };

    if intent_matches_kind {
        weight::CLEAR_QUERY_INTENT_SYMBOL_KINDS_SCORE_BONUS
    } else {
        0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::path::PathBuf;
    use crate::{
        models::{
            parsed::{Language, SymbolKind},
            resolved::{ResolvedSymbol, Score},
        },
        resolver::scoring::DEFAULT_SCORE,
    };

    // NOTE(review): the numeric literals in the expected scores below presumably
    // mirror the constants in the `weight` module; confirm them whenever the
    // weights are retuned.

    /// Builds a `ResolvedSymbol` fixture.
    ///
    /// Only the name, kind, language and path vary between tests. The id, the
    /// score and the source positions are fixed placeholders because neither
    /// `calculate_score` nor `calculate_clear_intent_bonus` reads them.
    fn symbol_fixture(
        name: &str,
        kind: SymbolKind,
        language: Language,
        path: impl Into<PathBuf>,
    ) -> ResolvedSymbol {
        ResolvedSymbol {
            id: 1,
            name: name.to_string(),
            kind,
            language,
            path: path.into(),
            score: Score::default(),
            start_line: 1,
            start_column: 1,
            end_line: 1,
            end_column: 1,
        }
    }

    /// Fixture for the clear-intent tests: a Rust symbol located in `src/lib.rs`.
    fn intent_fixture(name: &str, kind: SymbolKind) -> ResolvedSymbol {
        symbol_fixture(name, kind, Language::Rust, "src/lib.rs")
    }

    /// Scores `symbol` for an empty query and without any fuzzy matches, so that
    /// only the path- and kind-based adjustments contribute.
    fn score_without_matches(
        symbol: &ResolvedSymbol,
        current_file: Option<&std::path::Path>,
    ) -> i64 {
        calculate_score("", symbol, std::iter::empty(), current_file)
    }

    /// Runs frizbee (default scoring, at most one typo, unsorted) for `query`
    /// against two candidates: `<path>:<name>` and the bare symbol name.
    fn fuzzy_matches_for(query: &str, symbol: &ResolvedSymbol) -> Vec<frizbee::Match> {
        let config = frizbee::Config {
            max_typos: Some(1),
            sort: false,
            scoring: frizbee::Scoring::default(),
        };
        let qualified_name = format!(
            "{}:{}",
            symbol
                .path
                .to_str()
                .expect("fixture paths are valid UTF-8"),
            symbol.name
        );
        frizbee::match_list(
            query,
            &[qualified_name.as_str(), symbol.name.as_str()],
            &config,
        )
    }

    #[test]
    fn test_scoring_struct_in_entrypoint_file() {
        let symbol = symbol_fixture(
            "ResolvedSymbol",
            SymbolKind::Struct,
            Language::Rust,
            "/some/file/mod.rs",
        );
        let score = score_without_matches(&symbol, None);
        // Struct is a "common" kind (+35); `mod.rs` presumably counts as an entrypoint file (-2).
        assert_eq!(DEFAULT_SCORE + 35 - 2, score);
    }

    // NOTE(review): despite the test name, `Path::file_name` returns `file` for
    // `/some/file`. What this test pins is that a file name which is not an
    // entrypoint incurs no entrypoint penalty.
    #[test]
    fn test_scoring_struct_where_path_has_no_filename() {
        let symbol = symbol_fixture(
            "ResolvedSymbol",
            SymbolKind::Struct,
            Language::Rust,
            "/some/file",
        );
        let score = score_without_matches(&symbol, None);
        // Only the "common" kind bonus applies.
        assert_eq!(DEFAULT_SCORE + 35, score);
    }

    #[test]
    fn test_scoring_variable_in_far_away_file() {
        let symbol = symbol_fixture(
            "ResolvedSymbol",
            SymbolKind::Variable,
            Language::Rust,
            PathBuf::from_iter(["", "some", "file", "over", "here", "file.rs"]),
        );
        let current_file = PathBuf::from_iter([
            "a",
            "totally",
            "different",
            "file",
            "over",
            "there",
            "file.ts",
        ]);
        let score = score_without_matches(&symbol, Some(current_file.as_path()));
        // Variable is an "infrequent" kind (+15); -12 is presumably the distance penalty.
        assert_eq!(DEFAULT_SCORE + 15 - 12, score);
    }

    #[test]
    fn test_scoring_variable_in_same_file() {
        let symbol = symbol_fixture(
            "ResolvedSymbol",
            SymbolKind::Variable,
            Language::Rust,
            PathBuf::from_iter(["", "some", "file", "over", "here", "file.rs"]),
        );
        let current_file = PathBuf::from_iter(["", "some", "file", "over", "here", "file.rs"]);
        let score = score_without_matches(&symbol, Some(current_file.as_path()));
        // Variable is an "infrequent" kind (+15); -10 is presumably the same-file penalty.
        assert_eq!(DEFAULT_SCORE + 15 - 10, score);
    }

    #[test]
    fn test_scoring_module_symbol() {
        let symbol = symbol_fixture(
            "tests",
            SymbolKind::Module,
            Language::Rust,
            "some_module.rs",
        );
        let score = score_without_matches(&symbol, None);
        // Module is an "uncommon" kind and is penalised (-15).
        assert_eq!(DEFAULT_SCORE - 15, score);
    }

    #[test]
    fn test_scoring_class_in_test_file() {
        let symbol = symbol_fixture(
            "TestClass",
            SymbolKind::Class,
            Language::TypeScript,
            "some_file.test.ts",
        );
        let score = score_without_matches(&symbol, None);
        // Class is a "common" kind (+35); -10 is presumably the test-harness penalty.
        assert_eq!(DEFAULT_SCORE + 35 - 10, score);
    }

    // NOTE(review): the name says "type", but the symbol under test is a `Class`.
    #[test]
    fn test_scoring_type_in_autogenerated_folder() {
        let symbol = symbol_fixture(
            "TestClass",
            SymbolKind::Class,
            Language::TypeScript,
            "path/__generated__/index.ts",
        );
        let score = score_without_matches(&symbol, None);
        // Class is a "common" kind (+35); `index.ts` presumably counts as an entrypoint
        // file (-2) and `__generated__` as autogenerated code (-50).
        assert_eq!(DEFAULT_SCORE + 35 - 2 - 50, score);
    }

    #[test]
    fn test_scoring_fuzzy_matched_symbol() {
        let query = "Lem";
        let symbol = symbol_fixture(
            "TestLemma",
            SymbolKind::Lemma,
            Language::TypeScript,
            PathBuf::from_iter(["some", "file", "over", "there.ts"]),
        );
        let fuzzy_matches = fuzzy_matches_for(query, &symbol);
        let score = calculate_score(query, &symbol, fuzzy_matches.iter(), None);
        // Lemma has no kind adjustment, so the whole delta comes from the fuzzy matches.
        assert_eq!(DEFAULT_SCORE + 26, score);
    }

    #[test]
    fn test_scoring_fuzzy_matched_symbol_smartcase_insensitive() {
        let query = "lem";
        let symbol = symbol_fixture(
            "TestLemma",
            SymbolKind::Lemma,
            Language::Clojure,
            PathBuf::from_iter(["some", "file", "over", "there.ts"]),
        );
        let fuzzy_matches = fuzzy_matches_for(query, &symbol);
        let score = calculate_score(query, &symbol, fuzzy_matches.iter(), None);
        // The all-lowercase query scores slightly lower than the case-exact "Lem" above.
        assert_eq!(DEFAULT_SCORE + 24, score);
    }

    #[rstest]
    #[case(SymbolKind::Constant)]
    #[case(SymbolKind::StaticField)]
    #[case(SymbolKind::StaticVariable)]
    #[case(SymbolKind::StaticDataMember)]
    fn test_constant_screaming_case_intent(#[case] kind: SymbolKind) {
        let sym = intent_fixture("MAXSIZE", kind);
        let score = calculate_clear_intent_bonus("MAXSIZE", &sym);
        assert_eq!(score, weight::CLEAR_QUERY_INTENT_SYMBOL_KINDS_SCORE_BONUS);
    }

    #[test]
    fn test_constant_non_screaming_case_no_bonus() {
        let sym = intent_fixture("MAXSIZE", SymbolKind::Constant);
        let score = calculate_clear_intent_bonus("maxSize", &sym);
        assert_eq!(score, 0);
    }

    #[rstest]
    #[case(SymbolKind::Struct)]
    #[case(SymbolKind::Type)]
    #[case(SymbolKind::TypeAlias)]
    #[case(SymbolKind::Class)]
    #[case(SymbolKind::Enum)]
    #[case(SymbolKind::EnumMember)]
    #[case(SymbolKind::Interface)]
    #[case(SymbolKind::Trait)]
    #[case(SymbolKind::Protocol)]
    #[case(SymbolKind::Union)]
    fn test_type_like_upper_lower_mix(#[case] kind: SymbolKind) {
        let sym = intent_fixture("UserProfile", kind);
        let score = calculate_clear_intent_bonus("UserProfile", &sym);
        assert_eq!(score, weight::CLEAR_QUERY_INTENT_SYMBOL_KINDS_SCORE_BONUS);
    }

    #[test]
    fn test_type_like_snake_case_no_bonus() {
        let sym = intent_fixture("UserProfile", SymbolKind::Struct);
        let score = calculate_clear_intent_bonus("user_profile", &sym);
        assert_eq!(score, 0);
    }

    #[rstest]
    #[case("is_ready")]
    #[case("isReady")]
    fn test_function_intent(#[case] query: &str) {
        let sym = intent_fixture(query, SymbolKind::Function);
        let score = calculate_clear_intent_bonus(query, &sym);
        assert_eq!(score, weight::CLEAR_QUERY_INTENT_SYMBOL_KINDS_SCORE_BONUS);
    }

    #[rstest]
    #[case(SymbolKind::Getter)]
    #[case(SymbolKind::Function)]
    #[case(SymbolKind::Method)]
    #[case(SymbolKind::Predicate)]
    #[case(SymbolKind::TraitMethod)]
    #[case(SymbolKind::ProtocolMethod)]
    #[case(SymbolKind::AbstractMethod)]
    fn test_predicate_intent(#[case] kind: SymbolKind) {
        let sym = intent_fixture("is_ready", kind);
        let score = calculate_clear_intent_bonus("is_ready", &sym);
        assert_eq!(score, weight::CLEAR_QUERY_INTENT_SYMBOL_KINDS_SCORE_BONUS);
    }

    #[rstest]
    // Too short to express intent.
    #[case("is", SymbolKind::Function)]
    // Mixed case does not signal a constant.
    #[case("maxSize", SymbolKind::Constant)]
    // Snake case does not signal a type.
    #[case("is_ready", SymbolKind::Struct)]
    fn test_no_bonus_cases(#[case] query: &str, #[case] kind: SymbolKind) {
        let sym = intent_fixture("irrelevant", kind);
        let score = calculate_clear_intent_bonus(query, &sym);
        assert_eq!(score, 0);
    }

    #[test]
    fn test_short_query_no_bonus() {
        let sym = intent_fixture("is_ready", SymbolKind::Function);
        let score = calculate_clear_intent_bonus("is", &sym);
        assert_eq!(score, 0);
    }
}