use crate::store::{CodeBlock, DocumentBlock, TextBlock};
use std::collections::HashMap;
/// Stateless namespace type for the reranking functions implemented in this file.
///
/// It carries no data; the associated functions visible here take no `self`
/// and are invoked as `Reranker::function_name(..)`.
pub struct Reranker;
impl Reranker {
/// Re-scores code blocks against `query` and returns them sorted best-first.
///
/// For every block the current distance (treated as `1.0` when absent) is
/// multiplied by the text-match, symbol-match, path-relevance and
/// content-length factors, and the product is written back to
/// `block.distance`. The blocks are then sorted by ascending distance, so a
/// smaller value ranks earlier. The sort is stable: blocks with equal
/// distances keep their incoming relative order.
///
/// An empty input is returned unchanged.
pub fn rerank_code_blocks(mut blocks: Vec<CodeBlock>, query: &str) -> Vec<CodeBlock> {
    if blocks.is_empty() {
        return blocks;
    }

    // The factor helpers compare against lowercased content, so the query is
    // lowercased once here rather than per block.
    let query_lower = query.to_lowercase();

    for block in &mut blocks {
        let adjusted = block.distance.unwrap_or(1.0)
            * Self::text_match_factor(&block.content, &query_lower)
            * Self::symbol_match_factor(&block.symbols, &query_lower)
            * Self::path_relevance_factor(&block.path, &query_lower)
            * Self::content_length_factor(&block.content);
        block.distance = Some(adjusted);
    }

    // `total_cmp` is a genuine total order over f32 (NaN included). Mapping an
    // incomparable pair to `Equal` is not, and `sort_by` leaves the resulting
    // order unspecified (and may panic) when the comparator is inconsistent.
    blocks.sort_by(|a, b| {
        let left = a.distance.unwrap_or(1.0);
        let right = b.distance.unwrap_or(1.0);
        left.total_cmp(&right)
    });

    blocks
}
/// Re-scores document blocks against `query` and returns them sorted best-first.
///
/// For every block the current distance (treated as `1.0` when absent) is
/// multiplied by the title-match, text-match, path-relevance and
/// header-level factors, and the product is written back to
/// `block.distance`. The blocks are then sorted by ascending distance, so a
/// smaller value ranks earlier. The sort is stable: blocks with equal
/// distances keep their incoming relative order.
///
/// An empty input is returned unchanged.
pub fn rerank_document_blocks(
    mut blocks: Vec<DocumentBlock>,
    query: &str,
) -> Vec<DocumentBlock> {
    if blocks.is_empty() {
        return blocks;
    }

    // The factor helpers compare against lowercased text, so the query is
    // lowercased once here rather than per block.
    let query_lower = query.to_lowercase();

    for block in &mut blocks {
        let adjusted = block.distance.unwrap_or(1.0)
            * Self::title_match_factor(&block.title, &query_lower)
            * Self::text_match_factor(&block.content, &query_lower)
            * Self::path_relevance_factor(&block.path, &query_lower)
            * Self::header_level_factor(block.level);
        block.distance = Some(adjusted);
    }

    // `total_cmp` is a genuine total order over f32 (NaN included). Mapping an
    // incomparable pair to `Equal` is not, and `sort_by` leaves the resulting
    // order unspecified (and may panic) when the comparator is inconsistent.
    blocks.sort_by(|a, b| {
        let left = a.distance.unwrap_or(1.0);
        let right = b.distance.unwrap_or(1.0);
        left.total_cmp(&right)
    });

    blocks
}
/// Re-scores plain-text blocks against `query` and returns them sorted best-first.
///
/// For every block the current distance (treated as `1.0` when absent) is
/// multiplied by the text-match, path-relevance and content-length factors,
/// and the product is written back to `block.distance`. The blocks are then
/// sorted by ascending distance, so a smaller value ranks earlier. The sort
/// is stable: blocks with equal distances keep their incoming relative order.
///
/// An empty input is returned unchanged.
pub fn rerank_text_blocks(mut blocks: Vec<TextBlock>, query: &str) -> Vec<TextBlock> {
    if blocks.is_empty() {
        return blocks;
    }

    // The factor helpers compare against lowercased content, so the query is
    // lowercased once here rather than per block.
    let query_lower = query.to_lowercase();

    for block in &mut blocks {
        let adjusted = block.distance.unwrap_or(1.0)
            * Self::text_match_factor(&block.content, &query_lower)
            * Self::path_relevance_factor(&block.path, &query_lower)
            * Self::content_length_factor(&block.content);
        block.distance = Some(adjusted);
    }

    // `total_cmp` is a genuine total order over f32 (NaN included). Mapping an
    // incomparable pair to `Equal` is not, and `sort_by` leaves the resulting
    // order unspecified (and may panic) when the comparator is inconsistent.
    blocks.sort_by(|a, b| {
        let left = a.distance.unwrap_or(1.0);
        let right = b.distance.unwrap_or(1.0);
        left.total_cmp(&right)
    });

    blocks
}
/// Returns a distance multiplier describing how well `content` matches `query`.
///
/// Smaller values mean a stronger match; `1.0` is neutral. `content` is
/// lowercased here, while `query` is compared as given, so callers should
/// pass an already-lowercased query for case-insensitive matching.
///
/// * Query with no words (empty or whitespace only): `1.0`.
/// * Content contains the whole query string: `0.7` for a one-word query,
///   `0.5` for two or three words, `0.6` for longer queries.
/// * Otherwise, if some query words overlap content words (either one being a
///   substring of the other): a value sliding from just under `0.95` down to
///   `0.8` as the share of overlapping query words grows to 100%.
/// * No overlap at all: `1.0`.
fn text_match_factor(content: &str, query: &str) -> f32 {
    let query_words: Vec<&str> = query.split_whitespace().collect();

    // `str::contains("")` is always true, so an empty query would otherwise be
    // scored as a phrase hit on every block.
    if query_words.is_empty() {
        return 1.0;
    }

    let content_lower = content.to_lowercase();

    if content_lower.contains(query) {
        return match query_words.len() {
            1 => 0.7,
            2..=3 => 0.5,
            _ => 0.6,
        };
    }

    let content_words: Vec<&str> = content_lower.split_whitespace().collect();
    let matched = query_words
        .iter()
        .filter(|&&query_word| {
            content_words
                .iter()
                .any(|&word| word.contains(query_word) || query_word.contains(word))
        })
        .count();

    if matched == 0 {
        return 1.0;
    }

    // Lower multiplier = better rank, so the factor must shrink as more of the
    // query words are found.
    let match_ratio = matched as f32 / query_words.len() as f32;
    0.95 - match_ratio * 0.15
}
/// Returns a distance multiplier describing how well `title` matches `query`.
///
/// Smaller values mean a stronger match; `1.0` is neutral. `title` is
/// lowercased here, while `query` is compared as given, so callers should
/// pass an already-lowercased query for case-insensitive matching.
///
/// * Query with no words (empty or whitespace only): `1.0`.
/// * Title equals the query: `0.4`.
/// * Title contains the query: `0.5`.
/// * Query contains the title and the title is longer than two bytes: `0.6`.
/// * Otherwise, if some query words overlap title words (either one being a
///   substring of the other): a value sliding from just under `0.8` down to
///   `0.6` as the share of overlapping query words grows to 100%.
/// * No overlap at all: `1.0`.
fn title_match_factor(title: &str, query: &str) -> f32 {
    let query_words: Vec<&str> = query.split_whitespace().collect();

    // `str::contains("")` is always true, so an empty query would otherwise
    // count as a title hit for every block.
    if query_words.is_empty() {
        return 1.0;
    }

    let title_lower = title.to_lowercase();

    if title_lower == query {
        return 0.4;
    }
    if title_lower.contains(query) {
        return 0.5;
    }
    // The length guard keeps one- and two-byte titles from matching almost any query.
    if title_lower.len() > 2 && query.contains(title_lower.as_str()) {
        return 0.6;
    }

    let title_words: Vec<&str> = title_lower.split_whitespace().collect();
    let matched = query_words
        .iter()
        .filter(|&&query_word| {
            title_words
                .iter()
                .any(|&word| word.contains(query_word) || query_word.contains(word))
        })
        .count();

    if matched == 0 {
        return 1.0;
    }

    // Lower multiplier = better rank, so the factor must shrink as more of the
    // query words are found in the title.
    let match_ratio = matched as f32 / query_words.len() as f32;
    0.8 - match_ratio * 0.2
}
/// Returns `0.6` when any symbol matches the query, otherwise the neutral `1.0`.
///
/// Matching is case-insensitive and bidirectional: a symbol matches when it
/// contains the query or the query contains it. Blank symbols and blank
/// queries never match, because `str::contains("")` is true for every string
/// and would otherwise reward any block that merely has symbols.
fn symbol_match_factor(symbols: &[String], query: &str) -> f32 {
    // Lowercase once up front instead of twice per symbol.
    let query_lower = query.to_lowercase();
    if query_lower.trim().is_empty() {
        return 1.0;
    }

    let any_match = symbols
        .iter()
        .filter(|symbol| !symbol.trim().is_empty())
        .any(|symbol| {
            let symbol_lower = symbol.to_lowercase();
            symbol_lower.contains(query_lower.as_str())
                || query_lower.contains(symbol_lower.as_str())
        });

    if any_match {
        0.6
    } else {
        1.0
    }
}
/// Returns a distance multiplier based on whether `query` appears in `path`.
///
/// The comparison is case-insensitive. Smaller values mean a stronger match:
///
/// * `0.75` when the last `/`-separated component of the path contains the query,
/// * `0.85` when the query appears elsewhere in the path,
/// * `1.0` otherwise, including for a blank query (which would otherwise be
///   "contained" in every path, since `str::contains("")` is always true).
fn path_relevance_factor(path: &str, query: &str) -> f32 {
    let query_lower = query.to_lowercase();
    if query_lower.trim().is_empty() {
        return 1.0;
    }

    let path_lower = path.to_lowercase();

    // `rsplit` always yields at least one item (the whole string when there is
    // no '/'), so the fallback is only there to avoid an unwrap.
    let filename = path_lower
        .rsplit('/')
        .next()
        .unwrap_or(path_lower.as_str());

    if filename.contains(query_lower.as_str()) {
        0.75
    } else if path_lower.contains(query_lower.as_str()) {
        0.85
    } else {
        1.0
    }
}
/// Maps the size of `content` to a distance multiplier.
///
/// The size is `content.len()`, i.e. the UTF-8 byte length rather than the
/// number of characters. Tiers, by inclusive upper bound:
///
/// * up to 50 bytes: `0.9`
/// * up to 500 bytes: `0.95`
/// * up to 2000 bytes: `1.0`
/// * up to 5000 bytes: `0.98`
/// * anything larger: `0.95`
///
/// NOTE(review): the smallest tier gets the smallest multiplier, so very short
/// content is favoured the most. Confirm that this is intended.
fn content_length_factor(content: &str) -> f32 {
    let byte_len = content.len();

    if byte_len <= 50 {
        return 0.9;
    }
    if byte_len <= 500 {
        return 0.95;
    }
    if byte_len <= 2000 {
        return 1.0;
    }
    if byte_len <= 5000 {
        return 0.98;
    }
    0.95
}
/// Maps a header level to a distance multiplier.
///
/// Levels 1 through 4 map to `0.9`, `0.85`, `0.9` and `0.95` respectively;
/// every other level (including 0) yields the neutral `1.0`.
fn header_level_factor(level: usize) -> f32 {
    // Index 0 holds the factor for level 1, index 3 the factor for level 4.
    const LEVEL_FACTORS: [f32; 4] = [0.9, 0.85, 0.9, 0.95];

    level
        .checked_sub(1)
        .and_then(|index| LEVEL_FACTORS.get(index))
        .copied()
        .unwrap_or(1.0)
}
pub fn tf_idf_boost(blocks: &mut [CodeBlock], query: &str) {
let query_lower = query.to_lowercase();
let query_terms: Vec<&str> = query_lower.split_whitespace().collect();
let mut doc_freq: HashMap<String, usize> = HashMap::new();
let total_docs = blocks.len();
for block in blocks.iter() {
let content_lower = block.content.to_lowercase();
let mut seen_terms = std::collections::HashSet::new();
for term in &query_terms {
if content_lower.contains(term) && !seen_terms.contains(term) {
*doc_freq.entry(term.to_string()).or_insert(0) += 1;
seen_terms.insert(term);
}
}
}
for block in blocks.iter_mut() {
let content_lower = block.content.to_lowercase();
let mut tf_idf_score = 0.0;
for term in &query_terms {
let tf = content_lower.matches(term).count() as f32;
let df = doc_freq.get(*term).unwrap_or(&1);
let idf = (total_docs as f32 / *df as f32).ln();
tf_idf_score += tf * idf;
}
if tf_idf_score > 0.0 {
let boost_factor = (1.0 - (tf_idf_score / 10.0).min(0.3)).max(0.5);
if let Some(distance) = block.distance {
block.distance = Some(distance * boost_factor);
}
}
}
}
}