use std::collections::HashMap;
use crate::document::NodeId;
use crate::retrieval::search::{Bm25Params, STOPWORDS, extract_keywords};
use crate::utils::estimate_tokens;
use super::config::ScoringStrategyConfig;
/// A scoreable unit of document content: one tree node's text plus the
/// metadata needed for relevance scoring.
#[derive(Debug, Clone)]
pub struct ContentChunk {
    /// Identifier of the tree node this chunk was extracted from.
    pub node_id: NodeId,
    /// Section title of the node.
    pub title: String,
    /// Raw text content of the node.
    pub content: String,
    /// Depth of the node in the document tree (0 = shallowest).
    pub depth: usize,
}
impl ContentChunk {
    /// Assembles a chunk from its constituent parts.
    #[must_use]
    pub fn new(node_id: NodeId, title: String, content: String, depth: usize) -> Self {
        Self { node_id, title, content, depth }
    }

    /// Approximate number of tokens in this chunk's content.
    ///
    /// Delegates to [`estimate_tokens`]; the value is an estimate, not an
    /// exact tokenizer count.
    #[must_use]
    pub fn token_count(&self) -> usize {
        estimate_tokens(self.content.as_str())
    }
}
/// Individual signal values that are blended into a final relevance score.
///
/// Each field is expected to lie roughly in `[0.0, 1.0]`; see
/// [`ScoreComponents::final_score`] for the blending weights.
#[derive(Debug, Clone, Default)]
pub struct ScoreComponents {
    /// Fraction of query keywords found in the chunk (title + content).
    pub keyword_score: f32,
    /// Normalized BM25 score of the content against the query terms.
    pub bm25_score: f32,
    /// Multiplicative penalty favouring shallower nodes (`0.9^depth`).
    pub depth_penalty: f32,
    /// Bonus inherited from the parent node's score, if one exists.
    pub path_bonus: f32,
    /// Information-density heuristic of the content.
    pub density_score: f32,
}
impl ScoreComponents {
    // Blending weights for each signal; they intentionally sum to 1.0 so
    // that in-range component values cannot exceed 1.0 before clamping.
    const KEYWORD_WEIGHT: f32 = 0.35;
    const BM25_WEIGHT: f32 = 0.25;
    const DEPTH_WEIGHT: f32 = 0.15;
    const PATH_WEIGHT: f32 = 0.10;
    const DENSITY_WEIGHT: f32 = 0.15;

    /// Blends the individual signals into a single relevance score.
    ///
    /// Returns the weighted sum of the components, clamped to
    /// `[0.0, 1.0]` so out-of-range inputs cannot escape that interval.
    #[must_use]
    pub fn final_score(&self) -> f32 {
        let score = self.keyword_score * Self::KEYWORD_WEIGHT
            + self.bm25_score * Self::BM25_WEIGHT
            + self.depth_penalty * Self::DEPTH_WEIGHT
            + self.path_bonus * Self::PATH_WEIGHT
            + self.density_score * Self::DENSITY_WEIGHT;
        score.clamp(0.0, 1.0)
    }
}
/// A chunk paired with its computed relevance score and the per-signal
/// breakdown that produced it.
#[derive(Debug, Clone)]
pub struct ContentRelevance {
    /// The scored chunk.
    pub chunk: ContentChunk,
    /// Final blended score in `[0.0, 1.0]`.
    pub score: f32,
    /// Individual signal values behind `score`.
    pub components: ScoreComponents,
}
impl ContentRelevance {
    /// Bundles a chunk with its final score and score breakdown.
    #[must_use]
    pub fn new(chunk: ContentChunk, score: f32, components: ScoreComponents) -> Self {
        Self { chunk, score, components }
    }
}
/// Corpus-level statistics supplied when scoring a chunk (BM25 inputs plus
/// hierarchical context).
#[derive(Debug, Clone)]
pub struct ScoringContext {
    /// Average document length in words, used for BM25 length normalization.
    pub avg_doc_len: f32,
    /// Total number of documents in the corpus.
    pub doc_count: usize,
    /// Per-term document frequency: term -> number of documents containing it.
    pub doc_freq: HashMap<String, usize>,
    /// Score of the parent node, if any; feeds the path bonus.
    pub parent_score: Option<f32>,
}
impl Default for ScoringContext {
fn default() -> Self {
Self {
avg_doc_len: 100.0,
doc_count: 1,
doc_freq: HashMap::new(),
parent_score: None,
}
}
}
/// Scores [`ContentChunk`]s against a query using a configurable mix of
/// keyword matching, BM25, depth, path, and density signals.
#[derive(Debug)]
pub struct RelevanceScorer {
    /// Keywords extracted from the query (or supplied directly).
    query_keywords: Vec<String>,
    /// Which scoring signals to combine (e.g. whether BM25 is enabled).
    strategy: ScoringStrategyConfig,
    /// BM25 tuning parameters (`k1`, `b`).
    params: Bm25Params,
}
impl RelevanceScorer {
#[must_use]
pub fn new(query: &str, strategy: ScoringStrategyConfig) -> Self {
let query_keywords = extract_keywords(query);
Self {
query_keywords,
strategy,
params: Bm25Params::default(),
}
}
#[must_use]
pub fn with_keywords(keywords: Vec<String>, strategy: ScoringStrategyConfig) -> Self {
Self {
query_keywords: keywords,
strategy,
params: Bm25Params::default(),
}
}
#[must_use]
pub fn score_chunk(&self, chunk: &ContentChunk, ctx: &ScoringContext) -> ContentRelevance {
let mut components = ScoreComponents::default();
components.keyword_score = self.compute_keyword_score(&format!(
"{} {}",
chunk.title,
chunk.content
));
if matches!(
self.strategy,
ScoringStrategyConfig::KeywordWithBM25 | ScoringStrategyConfig::Hybrid
) {
components.bm25_score = self.compute_bm25_score(&chunk.content, ctx);
}
components.depth_penalty = 0.9_f32.powi(chunk.depth as i32);
components.path_bonus = ctx.parent_score.map(|s| s * 0.2).unwrap_or(0.0);
components.density_score = compute_density(&chunk.content);
let final_score = components.final_score();
ContentRelevance::new(chunk.clone(), final_score, components)
}
pub fn score_chunks<'a>(
&self,
chunks: &'a [ContentChunk],
ctx: &ScoringContext,
) -> Vec<ContentRelevance> {
chunks
.iter()
.map(|chunk| self.score_chunk(chunk, ctx))
.collect()
}
fn compute_keyword_score(&self, content: &str) -> f32 {
if self.query_keywords.is_empty() {
return 0.5; }
let content_lower = content.to_lowercase();
let content_words: std::collections::HashSet<&str> =
content_lower.split_whitespace().collect();
let matches = self
.query_keywords
.iter()
.filter(|kw| {
let kw_lower = kw.to_lowercase();
content_words.iter().any(|&w| w.contains(&kw_lower))
|| content_lower.contains(&kw_lower)
})
.count();
matches as f32 / self.query_keywords.len() as f32
}
fn compute_bm25_score(&self, content: &str, ctx: &ScoringContext) -> f32 {
if self.query_keywords.is_empty() {
return 0.0;
}
let doc_len = content.split_whitespace().count() as f32;
let mut score = 0.0;
for term in &self.query_keywords {
let term_lower = term.to_lowercase();
let tf = content.to_lowercase().matches(&term_lower).count() as f32;
if tf == 0.0 {
continue;
}
let df = ctx.doc_freq.get(&term_lower).copied().unwrap_or(1) as f32;
let idf = ((ctx.doc_count as f32 - df + 0.5) / (df + 0.5) + 1.0).ln();
let k1 = self.params.k1;
let b = self.params.b;
let numerator = tf * (k1 + 1.0);
let denominator = tf + k1 * (1.0 - b + b * doc_len / ctx.avg_doc_len);
score += idf * numerator / denominator;
}
let max_possible_score = self.query_keywords.len() as f32 * 5.0; (score / max_possible_score).clamp(0.0, 1.0)
}
#[must_use]
pub fn keywords(&self) -> &[String] {
&self.query_keywords
}
}
/// Heuristic information density of `content`, roughly in `[0.0, 1.0]`.
///
/// Combines two signals: the fraction of non-stopword tokens (weight 0.7)
/// and the fraction of tokens containing a digit or an uppercase letter —
/// a cheap proxy for entity-like tokens such as names and version numbers
/// (weight 0.3). Empty or whitespace-only content scores 0.0.
fn compute_density(content: &str) -> f32 {
    let tokens: Vec<&str> = content.split_whitespace().collect();
    if tokens.is_empty() {
        return 0.0;
    }
    let total = tokens.len() as f32;
    let mut stopwords = 0usize;
    let mut entities = 0usize;
    for token in &tokens {
        if STOPWORDS.contains(&token.to_lowercase().as_str()) {
            stopwords += 1;
        }
        if token.chars().any(|c| c.is_numeric() || c.is_uppercase()) {
            entities += 1;
        }
    }
    let stopword_ratio = stopwords as f32 / total;
    let entity_ratio = entities as f32 / total;
    (1.0 - stopword_ratio) * 0.7 + entity_ratio * 0.3
}
#[cfg(test)]
mod tests {
    use super::*;
    use indextree::Arena;

    /// Builds a throwaway `NodeId` backed by a fresh, private arena.
    fn make_test_node_id() -> NodeId {
        let mut arena = Arena::new();
        let node = crate::document::TreeNode {
            title: "Test".to_string(),
            structure: String::new(),
            content: String::new(),
            summary: String::new(),
            depth: 0,
            start_index: 0,
            end_index: 0,
            start_page: None,
            end_page: None,
            node_id: None,
            physical_index: None,
            token_count: None,
            references: Vec::new(),
        };
        NodeId(arena.new_node(node))
    }

    #[test]
    fn test_keyword_extraction() {
        let keywords = extract_keywords("What is the architecture of vectorless?");
        // Content words survive extraction; stopwords are dropped.
        assert!(keywords.contains(&"architecture".to_string()));
        assert!(keywords.contains(&"vectorless".to_string()));
        assert!(!keywords.contains(&"what".to_string()));
        assert!(!keywords.contains(&"the".to_string()));
    }

    #[test]
    fn test_keyword_score() {
        let scorer = RelevanceScorer::new(
            "vectorless architecture",
            ScoringStrategyConfig::KeywordOnly,
        );
        let chunk = ContentChunk::new(
            make_test_node_id(),
            "Test".to_string(),
            "Vectorless has a unique architecture for document retrieval.".to_string(),
            0,
        );
        let _ctx = ScoringContext::default();
        // Both keywords are present, so the score must beat the 0.5 baseline.
        let score = scorer.compute_keyword_score(&chunk.content);
        assert!(score > 0.5);
    }

    #[test]
    fn test_density_score() {
        // Lots of numerals and capitalized tokens, few stopwords.
        let dense = "Rust 1.85+ requires Cargo.toml configuration with [dependencies]";
        assert!(compute_density(dense) > 0.5);
        // Mostly stopwords and plain lowercase words.
        let sparse = "This is a test of the system with some words in it";
        assert!(compute_density(sparse) < 0.7);
    }

    #[test]
    fn test_depth_penalty() {
        let shallow_chunk = ContentChunk::new(
            make_test_node_id(),
            "Test".to_string(),
            "Content".to_string(),
            0,
        );
        let deep_chunk = ContentChunk::new(
            make_test_node_id(),
            "Test".to_string(),
            "Content".to_string(),
            5,
        );
        let scorer = RelevanceScorer::new("test", ScoringStrategyConfig::KeywordOnly);
        let ctx = ScoringContext::default();
        let shallow_relevance = scorer.score_chunk(&shallow_chunk, &ctx);
        let deep_relevance = scorer.score_chunk(&deep_chunk, &ctx);
        // The penalty decays with depth, so the shallower chunk keeps more.
        assert!(
            shallow_relevance.components.depth_penalty > deep_relevance.components.depth_penalty
        );
    }

    #[test]
    fn test_score_components_final_score() {
        let components = ScoreComponents {
            keyword_score: 0.8,
            bm25_score: 0.6,
            depth_penalty: 0.9,
            path_bonus: 0.1,
            density_score: 0.5,
        };
        let blended = components.final_score();
        // The weighted blend of in-range components stays in (0, 1].
        assert!(blended > 0.0 && blended <= 1.0);
    }
}