use crate::{
constraints::apply_constraints,
git::is_modified_status,
path_utils::calculate_distance_penalty,
sort_buffer::{sort_by_key_with_buffer, sort_with_buffer},
types::{FileItem, Score, ScoringContext},
};
use fff_query_parser::FuzzyQuery;
use neo_frizbee::Scoring;
use rayon::prelude::*;
use std::path::MAIN_SEPARATOR;
/// A borrowed view over the candidate files for one search pass.
///
/// `All` wraps the full input slice when the query has no constraints;
/// `Filtered` holds the subset that survived `apply_constraints`, as
/// references into the original slice.
pub(crate) enum FileItems<'a> {
// All candidate files, unfiltered.
All(&'a [FileItem]),
// Constraint-filtered subset, in original order.
Filtered(Vec<&'a FileItem>),
}
impl<'a> FileItems<'a> {
#[inline]
#[allow(dead_code)]
fn len(&self) -> usize {
match self {
FileItems::All(s) => s.len(),
FileItems::Filtered(v) => v.len(),
}
}
#[inline]
#[allow(dead_code)]
fn is_empty(&self) -> bool {
self.len() == 0
}
#[inline]
fn get(&self, index: usize) -> Option<&'a FileItem> {
match self {
FileItems::All(s) => s.get(index),
FileItems::Filtered(v) => v.get(index).copied(),
}
}
fn relative_paths(&self) -> Vec<&'a str> {
match self {
FileItems::All(s) => s.iter().map(|f| f.relative_path.as_str()).collect(),
FileItems::Filtered(v) => v.iter().map(|f| f.relative_path.as_str()).collect(),
}
}
#[inline]
fn index(&self, index: usize) -> &'a FileItem {
match self {
FileItems::All(s) => &s[index],
FileItems::Filtered(v) => v[index],
}
}
}
/// Fuzzy-matches every query part against the files' relative paths.
///
/// The first usable part seeds the match list; each additional part then
/// filters that list down to paths it also matches, accumulating scores
/// (saturating at `u16::MAX`). Parts shorter than two bytes are dropped.
fn match_fuzzy_parts(
    fuzzy_parts: &[&str],
    working_files: &FileItems<'_>,
    options: &neo_frizbee::Config,
) -> Vec<neo_frizbee::Match> {
    // Parts under two bytes are too noisy to be worth fuzzy-matching.
    let usable: Vec<&str> = fuzzy_parts
        .iter()
        .copied()
        .filter(|part| part.len() >= 2)
        .collect();
    let Some((&first, rest)) = usable.split_first() else {
        if !fuzzy_parts.is_empty() {
            tracing::debug!("match_fuzzy_parts: no valid parts after filtering, returning empty");
        }
        return vec![];
    };
    let haystack: Vec<&str> = working_files.relative_paths();
    let mut matches = neo_frizbee::match_list(first, &haystack, options);
    for &part in rest {
        // Never allow more typos than the part has bytes.
        let mut part_options = *options;
        part_options.max_typos = options.max_typos.map(|t| t.min(part.len() as u16));
        matches = matches
            .into_iter()
            .filter_map(|mut candidate| {
                let path = haystack.get(candidate.index as usize)?;
                let part_hits = neo_frizbee::match_list(part, &[*path], &part_options);
                let hit = part_hits.first()?;
                let combined = (candidate.score as u32).saturating_add(hit.score as u32);
                candidate.score = combined.min(u16::MAX as u32) as u16;
                Some(candidate)
            })
            .collect();
        // Every candidate failed this part; no point checking the remainder.
        if matches.is_empty() {
            break;
        }
    }
    matches
}
/// Fuzzy-matches `files` against the query in `context`, scores every hit,
/// and returns the sorted, paginated page plus the total match count.
///
/// Pipeline: apply query constraints -> fuzzy-match full relative paths ->
/// re-match bare file names for a bonus -> combine score components ->
/// sort and paginate.
pub fn match_and_score_files<'a>(
files: &'a [FileItem],
context: &ScoringContext,
) -> (Vec<&'a FileItem>, Vec<Score>, usize) {
if files.is_empty() {
return (vec![], vec![], 0);
}
let parsed = context.query;
// Narrow the candidate set with the query's constraints, if any.
let working_files: FileItems<'a> = if parsed.constraints.is_empty() {
FileItems::All(files)
} else {
match apply_constraints(files, &parsed.constraints) {
Some(filtered) if !filtered.is_empty() => FileItems::Filtered(filtered),
// Constraints matched nothing: no results at all.
Some(_) => {
return (vec![], vec![], 0);
}
// NOTE(review): `None` appears to mean "constraints not applicable,
// keep everything" -- confirm against apply_constraints' definition.
None => FileItems::All(files),
}
};
let query_trimmed: &str = parsed.raw_query.trim();
// Backing storage so the raw query can be exposed as a one-element slice.
let single_part_storage: [&str; 1] = [query_trimmed];
let fuzzy_parts: &[&str] = match &parsed.fuzzy_query {
FuzzyQuery::Text(t) if t.len() >= 2 => std::slice::from_ref(t),
FuzzyQuery::Parts(parts) if !parts.is_empty() => parts.as_slice(),
// Fuzzy query too short or part list empty: rank purely by frecency.
FuzzyQuery::Text(_) | FuzzyQuery::Parts(_) => {
return score_filtered_by_frecency(&working_files, context);
}
FuzzyQuery::Empty => {
if query_trimmed.len() < 2 {
return score_filtered_by_frecency(&working_files, context);
}
&single_part_storage
}
};
// Case-related bonuses are enabled only when the query itself contains
// uppercase characters.
let has_uppercase = fuzzy_parts
.iter()
.any(|p| p.chars().any(|c| c.is_uppercase()));
let query_contains_path_separator = fuzzy_parts.iter().any(|p| p.contains(MAIN_SEPARATOR));
let options = neo_frizbee::Config {
max_typos: Some(context.max_typos),
// Sorting happens later in sort_and_paginate, after score adjustments.
sort: false,
scoring: Scoring {
capitalization_bonus: if has_uppercase { 8 } else { 0 },
matching_case_bonus: if has_uppercase { 4 } else { 0 },
..Default::default()
},
};
// First pass: fuzzy-match all parts against the full relative paths.
let path_matches = match_fuzzy_parts(fuzzy_parts, &working_files, &options);
// Second pass haystack: the bare file name of each surviving path.
// Position i in this haystack corresponds to path_matches[i].
let primary_text = fuzzy_parts[0]; let haystack_of_filenames: Vec<&str> = path_matches
.iter()
.filter_map(|m| {
working_files
.get(m.index as usize)
.map(|f| f.file_name.as_str())
})
.collect();
let filename_matches = if query_contains_path_separator {
// A separator in the query means the user targets a path, not a file
// name -- skip the filename pass entirely.
vec![]
} else {
let mut list = if haystack_of_filenames.len() > 1000 {
neo_frizbee::match_list_parallel(
primary_text,
&haystack_of_filenames,
&options,
context.max_threads,
)
} else {
neo_frizbee::match_list(primary_text, &haystack_of_filenames, &options)
};
// Sort by haystack index so the merge-walk below can pair filename
// matches with path matches in a single pass.
if list.len() > 1000 {
list.par_sort_unstable_by_key(|m| m.index);
} else {
sort_by_key_with_buffer(&mut list, |m| m.index);
}
list
};
// Cursor into the index-sorted filename_matches; advanced in lockstep
// with the enumeration below (both are ordered by path-match position).
let mut next_filename_match_index = 0;
let results: Vec<_> = path_matches
.into_iter()
.enumerate()
.map(|(index, path_match)| {
let file_idx = path_match.index as usize;
let file = working_files.index(file_idx);
let mut base_score = path_match.score as i32;
// Frecency boost scales the match score by total_frecency_score %.
let frecency_boost = base_score.saturating_mul(file.total_frecency_score as i32) / 100;
// Files with a modified git status get a flat 15% boost.
let git_status_boost = if file.git_status.is_some_and(is_modified_status) {
base_score * 15 / 100
} else {
0
};
let distance_penalty =
calculate_distance_penalty(context.current_file, &file.relative_path);
// Merge-walk: take the filename match for this position, if present.
let filename_match = filename_matches
.get(next_filename_match_index)
.and_then(|m| {
if m.index == index as u32 {
next_filename_match_index += 1;
Some(m)
} else {
None
}
});
let mut has_special_filename_bonus = false;
let filename_bonus = match filename_match {
// Exact filename match: bonus is 40% of the filename score.
Some(filename_match) if filename_match.exact => {
filename_match.score as i32 / 5 * 2 }
// Filename matched at least as well as the path: promote the
// filename score to base and add up to 30 on top.
Some(filename_match)
if filename_match.score >= path_match.score
&& !query_contains_path_separator =>
{
base_score = filename_match.score as i32;
(base_score / 6)
.min(30)
}
// No filename match, but a well-known entry-point file
// (mod.rs, index.ts, ...) still gets a small 5% nudge.
None if is_special_entry_point_file(&file.file_name) => {
has_special_filename_bonus = true;
base_score * 5 / 100
}
_ => 0,
};
let current_file_penalty = calculate_current_file_penalty(file, base_score, context);
// Boost files the user previously opened for this exact query.
let combo_match_boost = {
let last_same_query_match = context
.last_same_query_match
.as_ref()
.filter(|m| m.file_path.as_os_str() == file.path.as_os_str());
match last_same_query_match {
// min_combo_count == 0 disables the threshold: flat boost.
Some(_) if context.min_combo_count == 0 => 1000,
Some(combo_match) if combo_match.open_count >= context.min_combo_count => {
combo_match.open_count as i32 * context.combo_boost_score_multiplier
}
// Below the combo threshold: much weaker per-open boost.
Some(combo_match) => combo_match.open_count as i32 * 5,
_ => 0,
}
};
let total = base_score
.saturating_add(frecency_boost)
.saturating_add(git_status_boost)
.saturating_add(distance_penalty)
.saturating_add(filename_bonus)
.saturating_add(current_file_penalty)
.saturating_add(combo_match_boost);
let score = Score {
total,
base_score,
current_file_penalty,
filename_bonus,
special_filename_bonus: if has_special_filename_bonus {
filename_bonus
} else {
0
},
frecency_boost,
git_status_boost,
distance_penalty,
combo_match_boost,
exact_match: path_match.exact || filename_match.is_some_and(|m| m.exact),
match_type: match filename_match {
Some(filename_match) if filename_match.exact => "exact_filename",
Some(_) => "fuzzy_filename",
None if path_match.exact => "exact_path",
None => "fuzzy_path",
},
};
(file, score)
})
.collect();
sort_and_paginate(results, context)
}
/// Returns true for well-known entry-point / index file names that receive
/// a small score bonus even when the filename itself did not fuzzy-match.
fn is_special_entry_point_file(filename: &str) -> bool {
    const ENTRY_POINT_NAMES: &[&str] = &[
        "mod.rs",
        "lib.rs",
        "main.rs",
        "index.js",
        "index.jsx",
        "index.ts",
        "index.tsx",
        "index.mjs",
        "index.cjs",
        "index.vue",
        "__init__.py",
        "__main__.py",
        "main.go",
        "main.c",
        "index.php",
        "main.rb",
        "index.rb",
    ];
    ENTRY_POINT_NAMES.contains(&filename)
}
pub(crate) fn score_filtered_by_frecency<'a>(
files: &FileItems<'a>,
context: &ScoringContext,
) -> (Vec<&'a FileItem>, Vec<Score>, usize) {
let score_file = |file: &'a FileItem| {
let total_frecency_score = file.access_frecency_score as i32
+ (file.modification_frecency_score as i32).saturating_mul(4);
let git_status_boost = if file.git_status.is_some_and(is_modified_status) {
total_frecency_score * 15 / 100
} else {
0
};
let current_file_penalty =
calculate_current_file_penalty(file, total_frecency_score, context);
let total = total_frecency_score
.saturating_add(git_status_boost)
.saturating_add(current_file_penalty);
let score = Score {
total,
base_score: 0,
filename_bonus: 0,
distance_penalty: 0,
special_filename_bonus: 0,
combo_match_boost: 0,
current_file_penalty,
frecency_boost: total_frecency_score,
git_status_boost,
exact_match: false,
match_type: "frecency",
};
(file, score)
};
let results: Vec<_> = match files {
FileItems::All(s) => s.par_iter().map(&score_file).collect(),
FileItems::Filtered(v) => v.iter().map(|&file| score_file(file)).collect(),
};
sort_and_paginate(results, context)
}
/// Negative adjustment applied when `file` is the currently open file:
/// an unmodified current file loses its full `base_score`, a git-modified
/// one loses only half. Returns 0 for every other file.
#[inline]
fn calculate_current_file_penalty(
    file: &FileItem,
    base_score: i32,
    context: &ScoringContext,
) -> i32 {
    let is_current_file = context
        .current_file
        .is_some_and(|current| file.relative_path.as_str() == current);
    if !is_current_file {
        return 0;
    }
    let deduction = match file.git_status {
        Some(status) if is_modified_status(status) => base_score / 2,
        _ => base_score,
    };
    -deduction
}
/// Sorts scored results descending by total score (ties broken by most
/// recently modified first) and applies offset/limit pagination.
///
/// Returns `(page_items, page_scores, total_matched)`, where
/// `total_matched` is the pre-pagination match count.
#[tracing::instrument(skip_all, level = tracing::Level::DEBUG)]
fn sort_and_paginate<'a>(
    mut results: Vec<(&'a FileItem, Score)>,
    context: &ScoringContext,
) -> (Vec<&'a FileItem>, Vec<Score>, usize) {
    let total_matched = results.len();
    if total_matched == 0 {
        return (vec![], vec![], 0);
    }
    let offset = context.pagination.offset;
    // limit == 0 means "no limit".
    let limit = if context.pagination.limit > 0 {
        context.pagination.limit
    } else {
        total_matched
    };
    if offset >= total_matched {
        tracing::warn!(
            offset = offset,
            total_matched = total_matched,
            "Pagination: offset >= total_matched, returning empty"
        );
        return (vec![], vec![], total_matched);
    }
    let items_needed = offset.saturating_add(limit).min(total_matched);
    // When only a small prefix is needed from a deep result set, partition
    // first so the full sort below touches just `items_needed` elements.
    let use_partial_sort = items_needed < total_matched / 2 && total_matched > 100;
    if use_partial_sort {
        results.select_nth_unstable_by(items_needed - 1, |a, b| {
            b.1.total
                .cmp(&a.1.total)
                .then_with(|| b.0.modified.cmp(&a.0.modified))
        });
        results.truncate(items_needed);
    }
    sort_with_buffer(&mut results, |a, b| {
        b.1.total
            .cmp(&a.1.total)
            .then_with(|| b.0.modified.cmp(&a.0.modified))
    });
    // Cut the requested page. Bug fix: the previous `results.len() > limit`
    // guard silently ignored a non-zero offset whenever the results fit
    // within `limit` (always the case with the default limit == 0), so the
    // first `offset` items were returned again. The offset must always be
    // dropped. `offset < results.len()` holds here: offset < total_matched,
    // and the partial sort keeps at least `offset + 1` elements.
    if offset > 0 {
        results.drain(0..offset);
    }
    results.truncate(limit);
    let (items, scores): (Vec<&FileItem>, Vec<Score>) = results.into_iter().unzip();
    (items, scores, total_matched)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::PaginationArgs;
    use fff_query_parser::QueryParser;
    use std::path::PathBuf;
    /// Builds a minimal `FileItem` plus a `Score` carrying only
    /// `total`/`base_score`, so ordering logic can be exercised in isolation.
    fn create_test_file(path: &str, score: i32, modified: u64) -> (FileItem, Score) {
        let name = path.split('/').next_back().unwrap_or(path).to_string();
        let item = FileItem::new_raw(
            PathBuf::from(path),
            path.to_string(),
            name,
            0,
            modified,
            None,
            false,
        );
        let score = Score {
            total: score,
            base_score: score,
            filename_bonus: 0,
            distance_penalty: 0,
            special_filename_bonus: 0,
            current_file_penalty: 0,
            frecency_boost: 0,
            git_status_boost: 0,
            exact_match: false,
            match_type: "test",
            combo_match_boost: 0,
        };
        (item, score)
    }
    #[test]
    fn test_partial_sort_descending() {
        let fixtures = vec![
            create_test_file("file1.rs", 100, 1000),
            create_test_file("file2.rs", 200, 2000),
            create_test_file("file3.rs", 50, 3000),
            create_test_file("file4.rs", 300, 4000),
            create_test_file("file5.rs", 150, 5000),
            create_test_file("file6.rs", 250, 6000),
            create_test_file("file7.rs", 80, 7000),
            create_test_file("file8.rs", 180, 8000),
            create_test_file("file9.rs", 120, 9000),
            create_test_file("file10.rs", 90, 10000),
        ];
        let scored: Vec<(&FileItem, Score)> = fixtures
            .iter()
            .map(|(item, score)| (item, score.clone()))
            .collect();
        let parser = QueryParser::default();
        let query = parser.parse("test");
        let context = ScoringContext {
            query: &query,
            max_threads: 1,
            max_typos: 2,
            current_file: None,
            last_same_query_match: None,
            project_path: None,
            combo_boost_score_multiplier: 100,
            min_combo_count: 3,
            pagination: PaginationArgs {
                offset: 0,
                limit: 0,
            },
        };
        let (items, scores, total) = sort_and_paginate(scored.clone(), &context);
        assert_eq!(total, 10);
        assert_eq!(scores.len(), 10);
        assert_eq!(scores[0].total, 300, "First should be highest score");
        assert_eq!(scores[1].total, 250, "Second should be second highest");
        assert_eq!(scores[2].total, 200, "Third should be third highest");
        assert_eq!(items[0].relative_path, "file4.rs");
        assert_eq!(items[1].relative_path, "file6.rs");
        assert_eq!(items[2].relative_path, "file2.rs");
    }
    #[test]
    fn test_partial_sort_with_same_scores() {
        // Ties on `total` must fall back to newest-modified-first.
        let fixtures = [
            create_test_file("file1.rs", 100, 5000),
            create_test_file("file2.rs", 100, 8000),
            create_test_file("file3.rs", 100, 3000),
            create_test_file("file4.rs", 200, 1000),
            create_test_file("file5.rs", 200, 9000),
        ];
        let scored: Vec<(&FileItem, Score)> = fixtures
            .iter()
            .map(|(item, score)| (item, score.clone()))
            .collect();
        let parser = QueryParser::default();
        let query = parser.parse("test");
        let context = ScoringContext {
            query: &query,
            max_threads: 1,
            max_typos: 2,
            current_file: None,
            last_same_query_match: None,
            project_path: None,
            combo_boost_score_multiplier: 100,
            min_combo_count: 3,
            pagination: PaginationArgs {
                offset: 0,
                limit: 0,
            },
        };
        let (items, scores, _) = sort_and_paginate(scored, &context);
        assert_eq!(scores.len(), 5);
        assert_eq!(scores[0].total, 200);
        assert_eq!(items[0].modified, 9000, "First 200 should be newest");
        assert_eq!(scores[1].total, 200);
        assert_eq!(items[1].modified, 1000, "Second 200 should be older");
        assert_eq!(scores[2].total, 100);
        assert_eq!(items[2].modified, 8000, "First 100 should be newest");
        assert_eq!(scores[3].total, 100);
        assert_eq!(items[3].modified, 5000);
        assert_eq!(scores[4].total, 100);
        assert_eq!(items[4].modified, 3000, "Last 100 should be oldest");
    }
    #[test]
    fn test_no_partial_sort_for_small_results() {
        // Small result sets skip the partial-sort path but still come back
        // fully ordered.
        let fixtures = [
            create_test_file("file1.rs", 100, 1000),
            create_test_file("file2.rs", 200, 2000),
            create_test_file("file3.rs", 50, 3000),
        ];
        let scored: Vec<(&FileItem, Score)> = fixtures
            .iter()
            .map(|(item, score)| (item, score.clone()))
            .collect();
        let parser = QueryParser::default();
        let query = parser.parse("test");
        let context = ScoringContext {
            query: &query,
            max_threads: 1,
            max_typos: 2,
            current_file: None,
            last_same_query_match: None,
            project_path: None,
            combo_boost_score_multiplier: 100,
            min_combo_count: 3,
            pagination: PaginationArgs {
                offset: 0,
                limit: 0,
            },
        };
        let (items, scores, _) = sort_and_paginate(scored, &context);
        assert_eq!(scores.len(), 3);
        assert_eq!(scores[0].total, 200);
        assert_eq!(scores[1].total, 100);
        assert_eq!(scores[2].total, 50);
        assert_eq!(items[0].relative_path, "file2.rs");
        assert_eq!(items[1].relative_path, "file1.rs");
        assert_eq!(items[2].relative_path, "file3.rs");
    }
}
#[cfg(test)]
mod multi_part_tests {
    /// Sanity-checks that neo_frizbee finds the needles this module relies
    /// on inside a realistic, deeply nested path.
    #[test]
    fn test_single_path_matching() {
        let path = "core_workflow_service/kafka_event_consumer/src/ai_part_extraction_request/ai_part_extraction_request_handler.rs";
        let options = neo_frizbee::Config {
            max_typos: Some(2),
            sort: false,
            ..Default::default()
        };
        let aipart_matches = neo_frizbee::match_list("aipart", &[path], &options);
        println!("'aipart' matches (max_typos=2): {:?}", aipart_matches);
        assert!(!aipart_matches.is_empty(), "'aipart' should match the path");
        let core_matches = neo_frizbee::match_list("core", &[path], &options);
        println!("'core' matches (max_typos=2): {:?}", core_matches);
        assert!(!core_matches.is_empty(), "'core' should match the path");
        let co_options = neo_frizbee::Config {
            max_typos: Some(2),
            ..options
        };
        let co_matches = neo_frizbee::match_list("co", &[path], &co_options);
        println!("'co' matches (max_typos=2): {:?}", co_matches);
        assert!(!co_matches.is_empty(), "'co' should match the path");
    }
    /// Same needles against the all-lowercase variant of the path.
    #[test]
    fn test_lowercase_path_matching() {
        let path = "core_workflow_service/kafka_event_consumer/src/ai_part_extraction_request/ai_part_extraction_request_handler.rs".to_lowercase();
        let options = neo_frizbee::Config {
            max_typos: Some(2),
            sort: false,
            ..Default::default()
        };
        let co_matches = neo_frizbee::match_list("co", &[path.as_str()], &options);
        println!("'co' matches lowercase path (max_typos=2): {:?}", co_matches);
        assert!(!co_matches.is_empty(), "'co' should match the lowercase path");
        let core_matches = neo_frizbee::match_list("core", &[path.as_str()], &options);
        println!("'core' matches lowercase path (max_typos=2): {:?}", core_matches);
        assert!(
            !core_matches.is_empty(),
            "'core' should match the lowercase path"
        );
    }
}