use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::Instant;
use crate::search::tokenizer;
mod index;
mod query;
#[cfg(test)]
mod tests;
pub use query::{QueryIntent, classify_query, needs_code_search};
/// An indexed file together with the term statistics used to score it.
#[derive(Debug, Clone)]
pub(crate) struct Document {
/// Display path relative to the index root; `Bm25Index::build` falls back to the
/// full path (lossily converted to UTF-8) when the file is not under the root.
pub(super) rel_path: String,
/// The path the file was indexed under; `Bm25Index::build` also uses it as the
/// key of this document in `Bm25Index::documents`.
pub(super) abs_path: PathBuf,
/// Term frequencies (token -> occurrence count). In `Bm25Index::build` this
/// merges tokens from the file contents and from `rel_path`.
pub(super) tf: HashMap<String, u32>,
/// Total number of tokens in the document, i.e. the sum of all `tf` values
/// (as computed in `Bm25Index::build`).
pub(super) token_count: u32,
}
/// One ranked hit, as produced by `Bm25Index::search`.
#[derive(Debug, Clone)]
pub struct SearchResult {
/// Path of the file relative to the index root.
pub rel_path: String,
/// Path the file was indexed under; `Bm25Index::format_results` displays it
/// instead of `rel_path` unless it is empty.
pub abs_path: PathBuf,
/// Relevance score from the ranking step; callers in this module treat a
/// higher score as more relevant.
pub score: f64,
/// 1-based line numbers of the snippet lines, in ascending order
/// (as filled in by `Bm25Index::search`).
pub matching_lines: Vec<usize>,
/// `"L<n>: <text>"` excerpts, one per entry of `matching_lines`.
pub snippets: Vec<String>,
}
/// A single incremental change to apply with `Bm25Index::update`.
///
/// Derives the standard traits so callers can log, copy and compare pending
/// updates (e.g. to deduplicate a batch before applying it).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IndexUpdate {
    /// (Re)index the file at this path, replacing any previous entry for it.
    Upsert(PathBuf),
    /// Drop the file at this path from the index.
    Remove(PathBuf),
}
/// In-memory BM25 index over a set of source files.
pub struct Bm25Index {
/// Indexed documents, keyed by the path each file was indexed under.
pub(super) documents: HashMap<PathBuf, Document>,
/// Document frequency: for each term, how many documents contain it
/// (each document contributes at most once per term, see `build`).
pub(super) doc_freq: HashMap<String, u32>,
/// Sum of `token_count` over all documents; used by `avg_doc_len`.
pub(super) total_tokens: u64,
/// Root directory that relative paths are computed against; set by `build`
/// and `update`.
pub(super) root: PathBuf,
}
impl Default for Bm25Index {
fn default() -> Self {
Self::new()
}
}
impl Bm25Index {
    /// Creates an empty index: no documents, no term statistics, empty root path.
    pub fn new() -> Self {
        Self {
            documents: HashMap::new(),
            doc_freq: HashMap::new(),
            total_tokens: 0,
            root: PathBuf::new(),
        }
    }

    /// Discards all existing state and indexes `files` from scratch.
    ///
    /// Files are read and tokenized in parallel with rayon. A file that cannot
    /// be read as UTF-8 text is skipped. A document's term frequencies combine
    /// tokens from its contents with tokens from its path relative to `root`
    /// (the full path is used when the file is not under `root`). Path
    /// components are therefore searchable too.
    ///
    /// If a path appears more than once in `files`, only its first occurrence
    /// is indexed. This keeps `doc_freq` and `total_tokens` consistent with
    /// `documents`.
    pub fn build(&mut self, files: &[PathBuf], root: &Path) {
        use rayon::prelude::*;
        let start = Instant::now();
        self.documents.clear();
        self.doc_freq.clear();
        self.total_tokens = 0;
        self.root = root.to_path_buf();
        let file_data: Vec<(PathBuf, String, HashMap<String, u32>, u32)> = files
            .par_iter()
            .filter_map(|file| {
                // Fails on I/O errors and on non-UTF-8 (e.g. binary) content.
                let content = std::fs::read_to_string(file).ok()?;
                let rel_path = file
                    .strip_prefix(root)
                    .unwrap_or(file)
                    .to_string_lossy()
                    .into_owned();
                let mut tf = tokenizer::tokenize_code(&content);
                for token in tokenizer::tokenize_path(&rel_path) {
                    *tf.entry(token).or_default() += 1;
                }
                let token_count: u32 = tf.values().sum();
                Some((file.clone(), rel_path, tf, token_count))
            })
            .collect();
        self.documents.reserve(file_data.len());
        for (abs_path, rel_path, tf, token_count) in file_data {
            // A path listed twice must not be counted twice: `documents` would
            // keep one entry while the corpus statistics counted both.
            if self.documents.contains_key(&abs_path) {
                continue;
            }
            self.total_tokens += u64::from(token_count);
            // Each distinct term counts once per document towards its DF.
            for term in tf.keys() {
                *self.doc_freq.entry(term.clone()).or_default() += 1;
            }
            self.documents.insert(
                abs_path.clone(),
                Document {
                    rel_path,
                    abs_path,
                    tf,
                    token_count,
                },
            );
        }
        tracing::info!(
            "BM25 full build: {} docs, {} terms in {:?}",
            self.documents.len(),
            self.doc_freq.len(),
            start.elapsed(),
        );
    }

    /// Applies incremental `updates` in order.
    ///
    /// `Upsert` removes any existing entry for the path (`remove_file`) and
    /// then re-indexes it (`index_file`). `Remove` only calls `remove_file`.
    /// If `updates` is empty, nothing changes at all, not even `root`.
    ///
    /// NOTE(review): the logged `+added` count is the number of `Upsert`
    /// requests. Whether each one actually produced a document depends on
    /// `index_file`, which is defined outside this file.
    pub fn update(&mut self, updates: &[IndexUpdate], root: &Path) {
        if updates.is_empty() {
            return;
        }
        let start = Instant::now();
        self.root = root.to_path_buf();
        let mut added = 0u32;
        let mut removed = 0u32;
        for update in updates {
            match update {
                IndexUpdate::Upsert(path) => {
                    self.remove_file(path);
                    self.index_file(path);
                    added += 1;
                }
                IndexUpdate::Remove(path) => {
                    self.remove_file(path);
                    removed += 1;
                }
            }
        }
        tracing::info!(
            "BM25 incremental: +{added} -{removed} files in {:?} (total: {} docs, {} terms)",
            start.elapsed(),
            self.documents.len(),
            self.doc_freq.len(),
        );
    }

    /// Removes every indexed document whose key is not in `current_files`.
    pub fn retain_files(&mut self, current_files: &std::collections::HashSet<PathBuf>) {
        // Collect first: `remove_file` needs `&mut self` while we'd be iterating keys.
        let to_remove: Vec<PathBuf> = self
            .documents
            .keys()
            .filter(|k| !current_files.contains(*k))
            .cloned()
            .collect();
        for path in &to_remove {
            self.remove_file(path);
        }
        if !to_remove.is_empty() {
            tracing::debug!("BM25: pruned {} deleted files", to_remove.len());
        }
    }

    /// Returns whether `file` is indexed. This is an exact key match against
    /// the path the file was indexed under.
    pub fn contains(&self, file: &Path) -> bool {
        self.documents.contains_key(file)
    }

    /// Mean token count per document.
    ///
    /// Returns `1.0` for an empty index. Presumably this keeps length
    /// normalization free of division by zero; confirm against the ranking code.
    pub(super) fn avg_doc_len(&self) -> f64 {
        let n = self.documents.len() as f64;
        if n > 0.0 {
            self.total_tokens as f64 / n
        } else {
            1.0
        }
    }

    /// Ranks documents for `query` and returns at most `max_results` hits.
    ///
    /// For each hit, the file is re-read from disk to collect up to five
    /// matching lines as snippets.
    pub fn search(&self, query: &str, max_results: usize) -> Vec<SearchResult> {
        let scored = self.rank(query, max_results);
        let query_tokens = tokenizer::tokenize_query(query);
        scored
            .into_iter()
            .map(|(rel_path, abs_path, score)| {
                let (matching_lines, snippets) =
                    extract_snippets_from_disk(&abs_path, &query_tokens);
                SearchResult {
                    rel_path,
                    abs_path,
                    score,
                    matching_lines,
                    snippets,
                }
            })
            .collect()
    }

    /// Same ranking as [`Self::search`], but returns only `(rel_path, score)`
    /// pairs and does no disk I/O.
    pub fn search_fast(&self, query: &str, max_results: usize) -> Vec<(String, f64)> {
        self.rank(query, max_results)
            .into_iter()
            .map(|(rel, _, score)| (rel, score))
            .collect()
    }

    /// Returns the `top_k` best `(rel_path, score)` pairs for `query`.
    ///
    /// Goes through [`Self::search_fast`]. Snippets are not needed here, so
    /// the hit files are not re-read from disk.
    pub fn relevant_files(&self, query: &str, top_k: usize) -> Vec<(String, f64)> {
        self.search_fast(query, top_k)
    }

    /// Returns up to `max_k` files, but only when the ranking shows a clear winner.
    ///
    /// Steps:
    /// 1. Fetch up to `2 * max_k` candidates.
    /// 2. Multiply by 1.5 the score of any candidate whose path contains, or is
    ///    contained in, a non-empty entry of `conversation_files`.
    /// 3. Re-sort the candidates.
    /// 4. Return nothing unless the top score is at least 1.5x the mean
    ///    candidate score.
    /// 5. Otherwise return the leading run of candidates, stopping before the
    ///    first step where the score falls by more than 2x.
    ///
    /// NOTE(review): with a single positive-scored candidate the mean equals the
    /// top score, so the gap check fails and the result is empty. Confirm this
    /// is intended.
    pub fn relevant_files_dynamic(
        &self,
        query: &str,
        max_k: usize,
        conversation_files: &[String],
    ) -> Vec<(String, f64)> {
        if max_k == 0 {
            return Vec::new();
        }
        // Over-fetch so the conversation boost can promote files ranked just below the cut.
        let mut scores = self.search_fast(query, max_k.saturating_mul(2));
        if scores.is_empty() {
            return Vec::new();
        }
        for (path, score) in &mut scores {
            // Skip empty entries: `str::contains("")` is always true and would
            // boost every candidate.
            let mentioned = conversation_files
                .iter()
                .filter(|cf| !cf.is_empty())
                .any(|cf| path.contains(cf.as_str()) || cf.contains(path.as_str()));
            if mentioned {
                *score *= 1.5;
            }
        }
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let top_score = scores[0].1;
        let mean_score: f64 = scores.iter().map(|s| s.1).sum::<f64>() / scores.len() as f64;
        // `+ 0.01` keeps the denominators non-zero when scores are 0.
        let gap_ratio = top_score / (mean_score + 0.01);
        if gap_ratio < 1.5 {
            return Vec::new();
        }
        let mut result = vec![scores[0].clone()];
        for window in scores.windows(2) {
            let drop = window[0].1 / (window[1].1 + 0.01);
            if drop > 2.0 {
                break;
            }
            result.push(window[1].clone());
        }
        result.truncate(max_k);
        result
    }

    /// Relative paths of all indexed documents, in unspecified (HashMap) order.
    pub fn indexed_files(&self) -> Vec<&str> {
        self.documents
            .values()
            .map(|d| d.rel_path.as_str())
            .collect()
    }

    /// Number of indexed documents.
    pub fn doc_count(&self) -> usize {
        self.documents.len()
    }

    /// Number of distinct terms across all indexed documents.
    pub fn term_count(&self) -> usize {
        self.doc_freq.len()
    }

    /// Renders `results` as a numbered, human-readable list.
    ///
    /// Each entry has the form `<n>. <path>[ [lines: a, b]] (score: x.xx)`.
    /// The path is `abs_path` unless that is empty, in which case `rel_path`
    /// is used. When `include_snippets` is set, each snippet follows on its
    /// own line. An empty slice yields `"No relevant results found."`.
    pub fn format_results(results: &[SearchResult], include_snippets: bool) -> String {
        use std::fmt::Write as _;

        if results.is_empty() {
            return "No relevant results found.".to_string();
        }
        let mut output = String::new();
        for (i, r) in results.iter().enumerate() {
            let display_path = if r.abs_path.as_os_str().is_empty() {
                r.rel_path.clone()
            } else {
                r.abs_path.display().to_string()
            };
            let line_hint = if r.matching_lines.is_empty() {
                String::new()
            } else {
                let lines: Vec<String> =
                    r.matching_lines.iter().map(ToString::to_string).collect();
                format!(" [lines: {}]", lines.join(", "))
            };
            // Writing into a `String` cannot fail, so the `fmt::Result` is ignored.
            let _ = writeln!(
                output,
                "{}. {}{} (score: {:.2})",
                i + 1,
                display_path,
                line_hint,
                r.score,
            );
            if include_snippets {
                for snippet in &r.snippets {
                    let _ = writeln!(output, " {snippet}");
                }
            }
        }
        output
    }
}
/// Reads `path` and selects up to five lines that best match the query.
///
/// A line's score is the number of entries in `query_tokens` that occur as a
/// substring of the lowercased line. Lines scoring zero are ignored. The five
/// best lines are kept, with the earlier line winning on equal scores. They are
/// returned in file order as:
/// - their 1-based line numbers, and
/// - `"L<n>: <trimmed text>"` snippets. Text longer than 120 bytes is shortened
///   with `crate::util::truncate_bytes` and suffixed with `...`.
///
/// If the file cannot be read as UTF-8 text, both vectors are empty.
fn extract_snippets_from_disk(path: &Path, query_tokens: &[String]) -> (Vec<usize>, Vec<String>) {
    let text = match std::fs::read_to_string(path) {
        Ok(text) => text,
        Err(_) => return (Vec::new(), Vec::new()),
    };

    // Tuples are (1-based line number, number of matching tokens, raw line).
    let mut ranked: Vec<(usize, usize, &str)> = text
        .lines()
        .enumerate()
        .filter_map(|(idx, raw)| {
            let lowered = raw.to_lowercase();
            let hits = query_tokens
                .iter()
                .filter(|tok| lowered.contains(tok.as_str()))
                .count();
            if hits == 0 {
                None
            } else {
                Some((idx + 1, hits, raw))
            }
        })
        .collect();

    // `sort_by_key` is stable, so among equally scored lines the earliest survive the cut.
    ranked.sort_by_key(|&(_, hits, _)| std::cmp::Reverse(hits));
    ranked.truncate(5);
    ranked.sort_by_key(|&(line_no, _, _)| line_no);

    let mut line_numbers = Vec::with_capacity(ranked.len());
    let mut snippets = Vec::with_capacity(ranked.len());
    for (line_no, _, raw) in ranked {
        let body = raw.trim();
        let snippet = if body.len() <= 120 {
            format!("L{line_no}: {body}")
        } else {
            format!("L{line_no}: {}...", crate::util::truncate_bytes(body, 120))
        };
        line_numbers.push(line_no);
        snippets.push(snippet);
    }
    (line_numbers, snippets)
}