use std::collections::HashMap;
use bm25::{
Embedder, EmbedderBuilder, Embedding, Language, Scorer, ScoredDocument,
DefaultTokenizer, Tokenizer,
};
/// Per-field multipliers used when the title, summary and content scores of a
/// document are blended into one weighted score.
#[derive(Debug, Clone, Copy)]
pub struct FieldWeights {
    /// Multiplier for the score obtained from the title field.
    pub title: f32,
    /// Multiplier for the score obtained from the summary field.
    pub summary: f32,
    /// Multiplier for the score obtained from the content field.
    pub content: f32,
}

impl Default for FieldWeights {
    /// Stock weighting: title `2.0`, summary `1.5`, content `1.0`.
    fn default() -> Self {
        let (title, summary, content) = (2.0, 1.5, 1.0);
        FieldWeights {
            title,
            summary,
            content,
        }
    }
}
/// Tuning values handed to the BM25 embedder builder.
#[derive(Debug, Clone, Copy)]
pub struct Bm25Params {
    /// Value forwarded to the builder's `k1` setting (BM25 term-frequency saturation).
    pub k1: f32,
    /// Value forwarded to the builder's `b` setting (BM25 length normalisation).
    pub b: f32,
    /// Average document length the embedder is constructed with.
    pub avgdl: f32,
}

impl Default for Bm25Params {
    /// Stock parameters: `k1 = 1.2`, `b = 0.75`, `avgdl = 100.0`.
    fn default() -> Self {
        let (k1, b, avgdl) = (1.2, 0.75, 100.0);
        Bm25Params { k1, b, avgdl }
    }
}
/// A document split into three searchable text fields and identified by `id`.
#[derive(Debug, Clone)]
pub struct FieldDocument<K> {
    /// Caller-chosen identifier for the document.
    pub id: K,
    /// Title text.
    pub title: String,
    /// Summary text.
    pub summary: String,
    /// Main body text.
    pub content: String,
}

impl<K> FieldDocument<K> {
    /// Builds a document from its identifier and its three text fields.
    pub fn new(id: K, title: String, summary: String, content: String) -> Self {
        FieldDocument {
            id,
            title,
            summary,
            content,
        }
    }

    /// Returns title, summary and content glued together with single spaces, in that order.
    ///
    /// Empty fields are not skipped, so their separators remain in the result.
    fn combined_text(&self) -> String {
        let parts = [
            self.title.as_str(),
            self.summary.as_str(),
            self.content.as_str(),
        ];
        parts.join(" ")
    }
}
// Hashable key pairing a document id with one of its fields.
// NOTE(review): nothing in the visible part of this file constructs or reads
// this type — confirm whether it is still needed.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct FieldKey<K> {
// Identifier of the document the key refers to.
doc_id: K,
// Which field of that document the key refers to.
field: Field,
}
// Names the three text fields a document carries; used as the `field`
// component of `FieldKey`.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
enum Field {
// The document title.
Title,
// The document summary.
Summary,
// The document body.
Content,
}
/// BM25 search engine over documents with separate title, summary and content fields.
///
/// One scorer indexes the combined text of each document and three further
/// scorers index the individual fields, so results can be ranked either by the
/// combined score or by a field-weighted score.
pub struct Bm25Engine<K> {
// Turns document text and query text into embeddings.
embedder: Embedder,
// Index over the combined title + summary + content text.
scorer: Scorer<K>,
// Index over title text only.
title_scorer: Scorer<K>,
// Index over summary text only.
summary_scorer: Scorer<K>,
// Index over content text only.
content_scorer: Scorer<K>,
// Multipliers applied to the per-field scores when blending them.
weights: FieldWeights,
// Counter reported by `len`/`is_empty`; maintained by `upsert` and `remove`.
doc_count: usize,
// True when the engine was built by `fit_to_corpus`, false for `new`/`with_params`.
fitted: bool,
}
impl<K: std::hash::Hash + Eq + Clone + std::fmt::Debug> Bm25Engine<K> {
pub fn new() -> Self {
Self::with_params(Bm25Params::default())
}
pub fn with_params(params: Bm25Params) -> Self {
let embedder = EmbedderBuilder::with_avgdl(params.avgdl)
.k1(params.k1)
.b(params.b)
.language_mode(Language::English)
.build();
Self {
embedder,
scorer: Scorer::new(),
title_scorer: Scorer::new(),
summary_scorer: Scorer::new(),
content_scorer: Scorer::new(),
weights: FieldWeights::default(),
doc_count: 0,
fitted: false,
}
}
pub fn fit_to_corpus(documents: &[FieldDocument<K>]) -> Self {
let corpus: Vec<String> = documents.iter()
.map(|d| d.combined_text())
.collect();
let corpus_refs: Vec<&str> = corpus.iter().map(|s| s.as_str()).collect();
let embedder = EmbedderBuilder::with_fit_to_corpus(Language::English, &corpus_refs)
.build();
let mut engine = Self {
embedder,
scorer: Scorer::new(),
title_scorer: Scorer::new(),
summary_scorer: Scorer::new(),
content_scorer: Scorer::new(),
weights: FieldWeights::default(),
doc_count: 0,
fitted: true,
};
for doc in documents {
engine.upsert(doc);
}
engine
}
pub fn with_weights(mut self, weights: FieldWeights) -> Self {
self.weights = weights;
self
}
pub fn with_language(mut self, language: Language) -> Self {
self.embedder = EmbedderBuilder::with_avgdl(self.embedder.avgdl())
.language_mode(language)
.build();
self
}
pub fn avgdl(&self) -> f32 {
self.embedder.avgdl()
}
pub fn is_fitted(&self) -> bool {
self.fitted
}
pub fn upsert(&mut self, document: &FieldDocument<K>) {
let id = &document.id;
let title_emb = self.embedder.embed(&document.title);
let summary_emb = self.embedder.embed(&document.summary);
let content_emb = self.embedder.embed(&document.content);
self.title_scorer.upsert(id, title_emb);
self.summary_scorer.upsert(id, summary_emb);
self.content_scorer.upsert(id, content_emb);
let combined = self.embedder.embed(&document.combined_text());
self.scorer.upsert(id, combined);
self.doc_count += 1;
}
pub fn remove(&mut self, id: &K) {
self.scorer.remove(id);
self.title_scorer.remove(id);
self.summary_scorer.remove(id);
self.content_scorer.remove(id);
self.doc_count = self.doc_count.saturating_sub(1);
}
pub fn len(&self) -> usize {
self.doc_count
}
pub fn is_empty(&self) -> bool {
self.doc_count == 0
}
pub fn score(&self, id: &K, query: &str) -> Option<f32> {
let query_emb = self.embedder.embed(query);
let title_score = self.title_scorer.score(id, &query_emb)?;
let summary_score = self.summary_scorer.score(id, &query_emb)?;
let content_score = self.content_scorer.score(id, &query_emb)?;
let total_weight = self.weights.title + self.weights.summary + self.weights.content;
let weighted_score = (title_score * self.weights.title
+ summary_score * self.weights.summary
+ content_score * self.weights.content) / total_weight;
Some(weighted_score)
}
pub fn search(&self, query: &str, limit: usize) -> Vec<ScoredDocument<K>> {
let query_emb = self.embedder.embed(query);
self.scorer.matches(&query_emb).into_iter().take(limit).collect()
}
pub fn search_weighted(&self, query: &str, limit: usize) -> Vec<(K, f32)> {
let query_emb = self.embedder.embed(query);
let all_results = self.scorer.matches(&query_emb);
let mut scored: Vec<(K, f32)> = all_results
.into_iter()
.filter_map(|scored_doc| {
let id = scored_doc.id;
let title_score = self.title_scorer.score(&id, &query_emb)?;
let summary_score = self.summary_scorer.score(&id, &query_emb)?;
let content_score = self.content_scorer.score(&id, &query_emb)?;
let total_weight = self.weights.title + self.weights.summary + self.weights.content;
let weighted_score = (title_score * self.weights.title
+ summary_score * self.weights.summary
+ content_score * self.weights.content) / total_weight;
Some((id, weighted_score))
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(limit);
scored
}
pub fn tokenize(&self, text: &str) -> Vec<String> {
let tokenizer = DefaultTokenizer::builder()
.language_mode(Language::English)
.normalization(true)
.stopwords(true)
.stemming(true)
.build();
tokenizer.tokenize(text)
}
pub fn embedder(&self) -> &Embedder {
&self.embedder
}
pub fn embedder_mut(&mut self) -> &mut Embedder {
&mut self.embedder
}
}
impl<K: std::hash::Hash + Eq + Clone + std::fmt::Debug> Default for Bm25Engine<K> {
fn default() -> Self {
Self::new()
}
}
/// A search query together with the extra terms produced by query expansion.
#[derive(Debug, Clone)]
pub struct ExpandedQuery {
    /// The query exactly as supplied.
    pub original: String,
    /// Additional terms or phrases generated for the query.
    pub expansions: Vec<String>,
    /// `original` followed by every expansion, separated by single spaces.
    pub combined: String,
}

impl ExpandedQuery {
    /// Builds an expanded query and pre-computes its `combined` text.
    ///
    /// Empty parts are skipped when joining, so an empty `expansions` list (or an
    /// empty `original`) does not leave stray leading, trailing or doubled spaces
    /// in `combined`.
    pub fn new(original: String, expansions: Vec<String>) -> Self {
        let combined = std::iter::once(original.as_str())
            .chain(expansions.iter().map(String::as_str))
            .filter(|part| !part.is_empty())
            .collect::<Vec<_>>()
            .join(" ");
        Self {
            original,
            expansions,
            combined,
        }
    }
}
/// Asynchronous strategy for turning a raw query into an [`ExpandedQuery`].
///
/// Implementors must be `Send + Sync`.
#[async_trait::async_trait]
pub trait QueryExpander: Send + Sync {
/// Produces the expanded form of `query`.
async fn expand(&self, query: &str) -> ExpandedQuery;
}
/// Lowercase English words that [`extract_keywords`] discards.
pub const STOPWORDS: &[&str] = &[
    "a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
    "have", "has", "had", "do", "does", "did", "will", "would", "could",
    "should", "may", "might", "must", "shall", "can", "need", "dare",
    "ought", "used", "to", "of", "in", "for", "on", "with", "at", "by",
    "from", "as", "into", "through", "during", "before", "after", "above",
    "below", "between", "under", "again", "further", "then", "once",
    "here", "there", "when", "where", "why", "how", "all", "each", "few",
    "more", "most", "other", "some", "such", "no", "nor", "not", "only",
    "own", "same", "so", "than", "too", "very", "just", "and", "but",
    "if", "or", "because", "until", "while", "about", "what", "which",
    "who", "whom", "this", "that", "these", "those", "i", "me", "my",
    "myself", "we", "our", "ours", "ourselves", "you", "your", "yours",
    "yourself", "yourselves", "he", "him", "his", "himself", "she", "her",
    "hers", "herself", "it", "its", "itself", "they", "them", "their",
    "theirs", "themselves",
];

/// Splits `query` into lowercase keywords, in order of appearance.
///
/// The query is lowercased and cut at every non-alphanumeric character. A piece
/// is kept when its UTF-8 encoding is longer than one byte and it is not listed
/// in [`STOPWORDS`]. Duplicates are preserved.
#[must_use]
pub fn extract_keywords(query: &str) -> Vec<String> {
    let lowered = query.to_lowercase();
    let mut keywords = Vec::new();
    for token in lowered.split(|c: char| !c.is_alphanumeric()) {
        // `len` counts bytes, so this also rejects the empty pieces produced by
        // adjacent separators.
        if token.len() <= 1 {
            continue;
        }
        if STOPWORDS.contains(&token) {
            continue;
        }
        keywords.push(token.to_owned());
    }
    keywords
}
// Unit tests for the BM25 engine, its parameter types and `ExpandedQuery`.
#[cfg(test)]
mod tests {
use super::*;
// A freshly constructed engine holds no documents and is not marked as fitted.
#[test]
fn test_bm25_engine_creation() {
let engine: Bm25Engine<u32> = Bm25Engine::new();
assert!(engine.is_empty());
assert!(!engine.is_fitted());
}
// Fitting to a corpus marks the engine as fitted and indexes every document.
#[test]
fn test_bm25_engine_fit_to_corpus() {
let docs = vec![
FieldDocument::new(1u32, "Rust Programming".to_string(), "About Rust".to_string(), "Rust is a systems programming language.".to_string()),
FieldDocument::new(2u32, "Python Guide".to_string(), "About Python".to_string(), "Python is a scripting language.".to_string()),
];
let engine = Bm25Engine::fit_to_corpus(&docs);
assert!(engine.is_fitted());
assert_eq!(engine.len(), 2);
}
// Combined-text search returns at least one of the documents that mention Rust.
#[test]
fn test_bm25_search() {
let docs = vec![
FieldDocument::new(1u32, "Rust Programming".to_string(), "About Rust".to_string(), "Rust is a systems programming language with memory safety.".to_string()),
FieldDocument::new(2u32, "Python Guide".to_string(), "About Python".to_string(), "Python is a scripting language for data science.".to_string()),
FieldDocument::new(3u32, "Rust Memory Safety".to_string(), "Memory in Rust".to_string(), "Rust provides guaranteed memory safety without garbage collection.".to_string()),
];
let engine = Bm25Engine::fit_to_corpus(&docs);
let results = engine.search("rust memory", 10);
assert!(!results.is_empty());
assert!(results.iter().any(|r| r.id == 1 || r.id == 3));
}
// With the title weighted highest, the document whose title contains the
// query term must outrank the one that only mentions it in its content.
#[test]
fn test_bm25_weighted_search() {
let docs = vec![
FieldDocument::new(1u32, "Rust Programming".to_string(), "About memory safety".to_string(), "Content about other things.".to_string()),
FieldDocument::new(2u32, "Other Language".to_string(), "About other things".to_string(), "Rust memory safety is important.".to_string()),
];
let engine = Bm25Engine::fit_to_corpus(&docs)
.with_weights(FieldWeights {
title: 3.0,
summary: 2.0,
content: 1.0,
});
let results = engine.search_weighted("rust", 10);
assert_eq!(results.first().map(|(id, _)| *id), Some(1u32));
}
// Scoring an indexed document against a matching query yields a positive score.
#[test]
fn test_bm25_score() {
let docs = vec![
FieldDocument::new(1u32, "Rust Programming".to_string(), "About Rust".to_string(), "Rust is a systems programming language.".to_string()),
];
let engine = Bm25Engine::fit_to_corpus(&docs);
let score = engine.score(&1u32, "rust programming");
assert!(score.is_some());
assert!(score.unwrap() > 0.0);
}
// Tokenization lowercases, stems ("programming" -> "program") and drops stopwords.
#[test]
fn test_bm25_tokenize() {
let engine: Bm25Engine<u32> = Bm25Engine::new();
let tokens = engine.tokenize("What is the Rust programming language?");
assert!(tokens.contains(&"rust".to_string()));
assert!(tokens.contains(&"program".to_string())); assert!(!tokens.contains(&"what".to_string())); assert!(!tokens.contains(&"the".to_string())); }
// Removing the only indexed document leaves the engine empty.
#[test]
fn test_bm25_remove() {
let docs = vec![
FieldDocument::new(1u32, "Rust".to_string(), "About Rust".to_string(), "Rust content.".to_string()),
];
let mut engine = Bm25Engine::fit_to_corpus(&docs);
assert_eq!(engine.len(), 1);
engine.remove(&1u32);
assert!(engine.is_empty());
}
// Default field weights are title 2.0, summary 1.5, content 1.0.
#[test]
fn test_field_weights_default() {
let weights = FieldWeights::default();
assert!((weights.title - 2.0).abs() < f32::EPSILON);
assert!((weights.summary - 1.5).abs() < f32::EPSILON);
assert!((weights.content - 1.0).abs() < f32::EPSILON);
}
// Default BM25 parameters are k1 1.2, b 0.75, avgdl 100.0.
#[test]
fn test_bm25_params_default() {
let params = Bm25Params::default();
assert!((params.k1 - 1.2).abs() < f32::EPSILON);
assert!((params.b - 0.75).abs() < f32::EPSILON);
assert!((params.avgdl - 100.0).abs() < f32::EPSILON);
}
// `combined` is the original query followed by the space-joined expansions.
#[test]
fn test_expanded_query() {
let expanded = ExpandedQuery::new(
"rust".to_string(),
vec!["programming".to_string(), "language".to_string()],
);
assert_eq!(expanded.original, "rust");
assert_eq!(expanded.expansions.len(), 2);
assert_eq!(expanded.combined, "rust programming language");
}
}