use std::collections::HashMap;
use crate::document::{DocumentTree, NodeId};
use super::bm25::Bm25Params;
pub use super::bm25::extract_keywords;
/// Selects which scoring method `ScoringContext::score` dispatches to.
#[derive(Default, Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScoringStrategy {
    /// Fraction of query terms found in each field (see `ScoringContext::quick_score`).
    KeywordOnly,
    /// Per-field BM25 weighting (see `ScoringContext::bm25_score`); the default.
    #[default]
    BM25,
    /// Blend of the two: 0.4 × keyword score + 0.6 × BM25 score.
    Hybrid,
}
/// Query terms, per-field weights and corpus statistics used to score document-tree nodes.
#[derive(Clone, Debug)]
pub struct ScoringContext {
    /// Keywords that node text is matched against.
    pub query_terms: Vec<String>,
    /// Multiplier applied to the title field's score.
    pub title_weight: f32,
    /// Multiplier applied to the summary field's score.
    pub summary_weight: f32,
    /// Multiplier applied to the content field's score.
    pub content_weight: f32,
    /// Score reduction per level of node depth; the scorers cap the total reduction.
    pub depth_penalty: f32,
    /// Which scoring method `score` dispatches to.
    pub strategy: ScoringStrategy,
    /// BM25 tuning parameters (`k1`, `b`).
    pub bm25_params: Bm25Params,
    /// Average field length for BM25 length normalization. Field lengths are measured in
    /// whitespace-separated words, so this is expected in the same unit.
    pub avg_doc_len: f32,
    /// Per-term document frequency used for IDF (presumably the number of documents
    /// containing the term — confirm against the producer of these stats).
    pub doc_freq: HashMap<String, usize>,
    /// Corpus size, used as `N` in the IDF formula.
    pub doc_count: usize,
}
impl Default for ScoringContext {
fn default() -> Self {
Self {
query_terms: Vec::new(),
title_weight: 2.0,
summary_weight: 1.5,
content_weight: 1.0,
depth_penalty: 0.1,
strategy: ScoringStrategy::default(),
bm25_params: Bm25Params::default(),
avg_doc_len: 100.0,
doc_freq: HashMap::new(),
doc_count: 1,
}
}
}
impl ScoringContext {
    /// Creates a context for `query` with the default strategy, weights and corpus stats.
    ///
    /// Query terms are produced by [`extract_keywords`].
    pub fn new(query: &str) -> Self {
        Self {
            query_terms: extract_keywords(query),
            ..Default::default()
        }
    }

    /// Creates a context for `query` that scores with the given `strategy`.
    pub fn with_strategy(query: &str, strategy: ScoringStrategy) -> Self {
        Self {
            query_terms: extract_keywords(query),
            strategy,
            ..Default::default()
        }
    }

    /// Replaces the BM25 parameters and returns the updated context.
    pub fn with_bm25_params(mut self, params: Bm25Params) -> Self {
        self.bm25_params = params;
        self
    }

    /// Installs corpus statistics used by BM25 and returns the updated context.
    ///
    /// `doc_count` is clamped to at least 1. `avg_doc_len` is clamped to at least 1.0
    /// because it is a divisor in BM25 length normalization.
    pub fn with_doc_stats(
        mut self,
        doc_count: usize,
        avg_doc_len: f32,
        doc_freq: HashMap<String, usize>,
    ) -> Self {
        self.doc_count = doc_count.max(1);
        self.avg_doc_len = avg_doc_len.max(1.0);
        self.doc_freq = doc_freq;
        self
    }

    /// Counts non-overlapping occurrences of `term` in `lowered_text`.
    ///
    /// `lowered_text` must already be lowercased by the caller. Matching is by substring,
    /// the same as in `keyword_overlap`, so `"rust"` is also counted inside `"trust"`.
    fn term_frequency(&self, lowered_text: &str, term: &str) -> f32 {
        lowered_text.matches(term).count() as f32
    }

    /// BM25 inverse document frequency: `ln(1 + (N - df + 0.5) / (df + 0.5))`.
    ///
    /// A term missing from `doc_freq` is treated as appearing in one document. `N` is
    /// clamped to at least 1 and `df` to at most `N`, so inconsistent statistics cannot
    /// produce a negative IDF (which would penalize matching terms).
    fn idf(&self, term: &str) -> f32 {
        let n = self.doc_count.max(1) as f32;
        let df = (self.doc_freq.get(term).copied().unwrap_or(1) as f32).min(n);
        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
    }

    /// Sums the BM25 contribution of every query term found in `text`.
    ///
    /// Field length is the number of whitespace-separated words in `text`. Returns 0.0
    /// when there are no query terms.
    fn bm25_field_score(&self, text: &str) -> f32 {
        if self.query_terms.is_empty() {
            return 0.0;
        }
        // Lowercase the field once instead of once per query term.
        let lowered = text.to_lowercase();
        let doc_len = text.split_whitespace().count() as f32;
        let k1 = self.bm25_params.k1;
        let b = self.bm25_params.b;
        // The length-normalization term depends only on the field, not on the query term.
        let length_norm = k1 * (1.0 - b + b * doc_len / self.avg_doc_len);
        let mut score = 0.0;
        for term in &self.query_terms {
            let tf = self.term_frequency(&lowered, term);
            if tf == 0.0 {
                continue;
            }
            let idf = self.idf(term);
            score += idf * (tf * (k1 + 1.0)) / (tf + length_norm);
        }
        score
    }

    /// Fraction of query terms that appear (case-insensitively, as substrings) in `text`.
    ///
    /// Returns a value in `[0, 1]`, or 0.0 when there are no query terms.
    fn keyword_overlap(&self, text: &str) -> f32 {
        if self.query_terms.is_empty() {
            return 0.0;
        }
        let text_lower = text.to_lowercase();
        let matches = self
            .query_terms
            .iter()
            .filter(|term| text_lower.contains(term.as_str()))
            .count();
        matches as f32 / self.query_terms.len() as f32
    }

    /// Multiplier that favors shallower nodes.
    ///
    /// For non-negative `depth` and `depth_penalty` the result lies in `[0.5, 1.0]`,
    /// because the total reduction is capped at 0.5.
    fn depth_factor(&self, depth: f32) -> f32 {
        1.0 - (depth * self.depth_penalty).min(0.5)
    }

    /// Weighted mean of per-field keyword overlap, scaled by the depth factor.
    ///
    /// Returns 0.0 for an unknown node, and also when all field weights are zero (where
    /// the mean would otherwise be NaN).
    pub fn quick_score(&self, tree: &DocumentTree, node_id: NodeId) -> f32 {
        if let Some(node) = tree.get(node_id) {
            let weight_sum = self.title_weight + self.summary_weight + self.content_weight;
            if weight_sum == 0.0 {
                // 0/0 would yield NaN, which breaks downstream ranking.
                return 0.0;
            }
            let weighted = self.keyword_overlap(&node.title) * self.title_weight
                + self.keyword_overlap(&node.summary) * self.summary_weight
                + self.keyword_overlap(&node.content) * self.content_weight;
            (weighted / weight_sum) * self.depth_factor(node.depth as f32)
        } else {
            0.0
        }
    }

    /// Weighted sum of per-field BM25 scores, squashed with `tanh` and scaled by the
    /// depth factor. Returns 0.0 for an unknown node.
    pub fn bm25_score(&self, tree: &DocumentTree, node_id: NodeId) -> f32 {
        if let Some(node) = tree.get(node_id) {
            let title_score = self.bm25_field_score(&node.title) * self.title_weight;
            let summary_score = self.bm25_field_score(&node.summary) * self.summary_weight;
            let content_score = self.bm25_field_score(&node.content) * self.content_weight;
            let total_score = title_score + summary_score + content_score;
            // Fixed scale before tanh; 3.0 presumably corresponds to the three fields.
            let normalized = (total_score / 3.0).tanh();
            normalized * self.depth_factor(node.depth as f32)
        } else {
            0.0
        }
    }

    /// Blend of the two scorers: 0.4 × `quick_score` + 0.6 × `bm25_score`.
    pub fn hybrid_score(&self, tree: &DocumentTree, node_id: NodeId) -> f32 {
        let keyword = self.quick_score(tree, node_id);
        let bm25 = self.bm25_score(tree, node_id);
        keyword * 0.4 + bm25 * 0.6
    }

    /// Scores `node_id` using the configured [`ScoringStrategy`].
    pub fn score(&self, tree: &DocumentTree, node_id: NodeId) -> f32 {
        match self.strategy {
            ScoringStrategy::KeywordOnly => self.quick_score(tree, node_id),
            ScoringStrategy::BM25 => self.bm25_score(tree, node_id),
            ScoringStrategy::Hybrid => self.hybrid_score(tree, node_id),
        }
    }
}
/// Scores and ranks document-tree nodes against a query using a [`ScoringContext`].
#[derive(Debug, Clone)]
pub struct NodeScorer {
    context: ScoringContext,
}
impl NodeScorer {
    /// Wraps an existing scoring context.
    pub fn new(context: ScoringContext) -> Self {
        Self { context }
    }

    /// Builds a scorer for `query` with the default strategy.
    pub fn for_query(query: &str) -> Self {
        Self::new(ScoringContext::new(query))
    }

    /// Builds a scorer for `query` using `strategy`.
    pub fn with_strategy(query: &str, strategy: ScoringStrategy) -> Self {
        Self::new(ScoringContext::with_strategy(query, strategy))
    }

    /// Shared access to the underlying context.
    pub fn context(&self) -> &ScoringContext {
        &self.context
    }

    /// Mutable access to the underlying context (e.g. to tweak weights).
    pub fn context_mut(&mut self) -> &mut ScoringContext {
        &mut self.context
    }

    /// Scores a single node with the context's configured strategy.
    pub fn score(&self, tree: &DocumentTree, node_id: NodeId) -> f32 {
        self.context.score(tree, node_id)
    }

    /// Scores every id in `node_ids` and returns `(id, score)` pairs, highest score first.
    ///
    /// The comparator is a total order: NaN ranks lowest, and `-0.0`/`0.0` compare equal.
    /// Ties keep their input order because the sort is stable.
    pub fn score_and_sort(&self, tree: &DocumentTree, node_ids: &[NodeId]) -> Vec<(NodeId, f32)> {
        let mut scored: Vec<(NodeId, f32)> = node_ids
            .iter()
            .map(|&id| (id, self.score(tree, id)))
            .collect();
        scored.sort_by(|a, b| Self::rank_key(b.1).total_cmp(&Self::rank_key(a.1)));
        scored
    }

    /// Maps a score to a key that `total_cmp` orders sensibly: NaN sorts below every
    /// real score, and both zeros collapse to `+0.0`.
    fn rank_key(score: f32) -> f32 {
        if score.is_nan() {
            f32::NEG_INFINITY
        } else if score == 0.0 {
            0.0
        } else {
            score
        }
    }

    /// Keyword-overlap score of an arbitrary text chunk.
    pub fn chunk_score(&self, chunk: &str) -> f32 {
        self.context.keyword_overlap(chunk)
    }

    /// Splits the node's `title summary content` text into chunks of at most `chunk_size`
    /// bytes. It sums each chunk's keyword overlap and divides by `sqrt(chunk_count + 1)`.
    ///
    /// Chunks are cut only on UTF-8 character boundaries, so non-ASCII text is never
    /// silently discarded. `chunk_size == 0` treats the whole text as a single chunk.
    /// A query term that straddles a chunk boundary is not matched in either chunk.
    /// Returns 0.0 for an unknown node.
    pub fn node_score(&self, tree: &DocumentTree, node_id: NodeId, chunk_size: usize) -> f32 {
        if let Some(node) = tree.get(node_id) {
            let content = format!("{} {} {}", node.title, node.summary, node.content);
            let chunks = Self::split_chunks(&content, chunk_size);
            if chunks.is_empty() {
                return 0.0;
            }
            let total_score: f32 = chunks.iter().map(|c| self.chunk_score(c)).sum();
            let n = chunks.len() as f32;
            total_score / (n + 1.0).sqrt()
        } else {
            0.0
        }
    }

    /// Splits `text` into consecutive slices of at most `max_bytes` bytes, cutting only
    /// at char boundaries.
    ///
    /// A single character wider than `max_bytes` becomes its own chunk. For ASCII text
    /// this produces the same chunks as fixed-size byte chunking. `max_bytes == 0` yields
    /// the whole (non-empty) text as one chunk.
    fn split_chunks(text: &str, max_bytes: usize) -> Vec<&str> {
        if max_bytes == 0 {
            return if text.is_empty() { Vec::new() } else { vec![text] };
        }
        let mut chunks = Vec::with_capacity(text.len() / max_bytes + 1);
        let mut start = 0;
        for (idx, ch) in text.char_indices() {
            let end = idx + ch.len_utf8();
            // Close the current chunk before the char that would overflow it; never emit
            // an empty chunk (idx > start).
            if end - start > max_bytes && idx > start {
                chunks.push(&text[start..idx]);
                start = idx;
            }
        }
        if start < text.len() {
            chunks.push(&text[start..]);
        }
        chunks
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_keywords() {
        let keywords = extract_keywords("What is the architecture of vectorless?");
        assert!(keywords.contains(&"architecture".to_string()));
        assert!(keywords.contains(&"vectorless".to_string()));
        // Stop words are filtered out.
        assert!(!keywords.contains(&"what".to_string()));
        assert!(!keywords.contains(&"the".to_string()));
    }

    #[test]
    fn test_keyword_overlap() {
        let ctx = ScoringContext::new("vectorless architecture");
        let text = "Vectorless has a unique architecture for document retrieval.";
        let score = ctx.keyword_overlap(text);
        assert!(score > 0.5);
    }

    #[test]
    fn test_keyword_overlap_empty_query() {
        let ctx = ScoringContext::default();
        assert_eq!(ctx.keyword_overlap("anything at all"), 0.0);
        assert_eq!(ctx.bm25_field_score("anything at all"), 0.0);
    }

    #[test]
    fn test_bm25_scoring() {
        let ctx = ScoringContext::with_strategy("rust cargo", ScoringStrategy::BM25);
        let text = "Rust is a programming language. Cargo is its package manager. Rust Rust Rust.";
        let score = ctx.bm25_field_score(text);
        assert!(score > 0.0);
    }

    #[test]
    fn test_bm25_no_match_scores_zero() {
        let ctx = ScoringContext::with_strategy("rust cargo", ScoringStrategy::BM25);
        assert_eq!(ctx.bm25_field_score("python pip"), 0.0);
    }

    #[test]
    fn test_idf_positive_with_default_stats() {
        let ctx = ScoringContext::default();
        assert!(ctx.idf("anything") > 0.0);
    }

    #[test]
    fn test_hybrid_scoring() {
        let ctx = ScoringContext::with_strategy("test query", ScoringStrategy::Hybrid);
        let text = "test query content";
        let keyword_score = ctx.keyword_overlap(text);
        let bm25_score = ctx.bm25_field_score(text);
        assert!(keyword_score > 0.0);
        assert!(bm25_score > 0.0);
        let hybrid = keyword_score * 0.4 + bm25_score * 0.6;
        assert!(hybrid > 0.0);
    }

    #[test]
    fn test_scorer_creation() {
        let scorer = NodeScorer::for_query("test query");
        assert!(!scorer.context().query_terms.is_empty());
    }

    #[test]
    fn test_scorer_with_strategy() {
        let scorer = NodeScorer::with_strategy("test", ScoringStrategy::BM25);
        assert_eq!(scorer.context().strategy, ScoringStrategy::BM25);
    }
}