use bm25::{
DefaultTokenizer, Embedder, EmbedderBuilder, Language, ScoredDocument, Scorer, Tokenizer,
};
/// Relative weights applied to the per-field BM25 scores when they are blended.
///
/// The engine divides the weighted sum by the sum of the weights, so only the ratios between
/// the three weights matter.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FieldWeights {
    /// Weight of the title-field score.
    pub title: f32,
    /// Weight of the summary-field score.
    pub summary: f32,
    /// Weight of the content-field score.
    pub content: f32,
}
impl Default for FieldWeights {
fn default() -> Self {
Self {
title: 2.0,
summary: 1.5,
content: 1.0,
}
}
}
/// BM25 tuning parameters used to build the embedder.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Bm25Params {
    /// Term-frequency saturation parameter (`k1` in the standard BM25 formula).
    pub k1: f32,
    /// Document-length normalization strength (`b` in the standard BM25 formula).
    pub b: f32,
    /// Average document length that lengths are normalized against.
    pub avgdl: f32,
}
impl Default for Bm25Params {
fn default() -> Self {
Self {
k1: 1.2,
b: 0.75,
avgdl: 100.0,
}
}
}
/// A searchable document split into three independently indexed text fields.
#[derive(Clone, Debug)]
pub struct FieldDocument<K> {
    /// Caller-chosen identifier; used as the key in every index.
    pub id: K,
    /// Title text.
    pub title: String,
    /// Summary text.
    pub summary: String,
    /// Main body text.
    pub content: String,
}
impl<K> FieldDocument<K> {
pub fn new(id: K, title: String, summary: String, content: String) -> Self {
Self {
id,
title,
summary,
content,
}
}
fn combined_text(&self) -> String {
format!("{} {} {}", self.title, self.summary, self.content)
}
}
/// Key addressing one field of one document.
///
/// NOTE(review): nothing in this module constructs or reads this type; presumably it is
/// reserved for a per-field index keyed by (document, field). The `allow` keeps
/// `-D warnings` builds passing until it is either used or deleted.
#[allow(dead_code)]
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct FieldKey<K> {
    /// Identifier of the owning document.
    doc_id: K,
    /// Which field of that document is addressed.
    field: Field,
}
/// The three indexed text fields of a [`FieldDocument`].
///
/// NOTE(review): only referenced by the (itself unused) `FieldKey`; the `allow` keeps
/// `-D warnings` builds passing until it is either used or deleted.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
enum Field {
    /// The document title.
    Title,
    /// The document summary.
    Summary,
    /// The document body.
    Content,
}
pub struct Bm25Engine<K> {
embedder: Embedder,
scorer: Scorer<K>,
title_scorer: Scorer<K>,
summary_scorer: Scorer<K>,
content_scorer: Scorer<K>,
weights: FieldWeights,
doc_count: usize,
fitted: bool,
}
impl<K: std::hash::Hash + Eq + Clone + std::fmt::Debug> Bm25Engine<K> {
pub fn new() -> Self {
Self::with_params(Bm25Params::default())
}
pub fn with_params(params: Bm25Params) -> Self {
let embedder = EmbedderBuilder::with_avgdl(params.avgdl)
.k1(params.k1)
.b(params.b)
.language_mode(Language::English)
.build();
Self {
embedder,
scorer: Scorer::new(),
title_scorer: Scorer::new(),
summary_scorer: Scorer::new(),
content_scorer: Scorer::new(),
weights: FieldWeights::default(),
doc_count: 0,
fitted: false,
}
}
pub fn fit_to_corpus(documents: &[FieldDocument<K>]) -> Self {
let corpus: Vec<String> = documents.iter().map(|d| d.combined_text()).collect();
let corpus_refs: Vec<&str> = corpus.iter().map(|s| s.as_str()).collect();
let embedder = EmbedderBuilder::with_fit_to_corpus(Language::English, &corpus_refs).build();
let mut engine = Self {
embedder,
scorer: Scorer::new(),
title_scorer: Scorer::new(),
summary_scorer: Scorer::new(),
content_scorer: Scorer::new(),
weights: FieldWeights::default(),
doc_count: 0,
fitted: true,
};
for doc in documents {
engine.upsert(doc);
}
engine
}
pub fn with_weights(mut self, weights: FieldWeights) -> Self {
self.weights = weights;
self
}
pub fn with_language(mut self, language: Language) -> Self {
self.embedder = EmbedderBuilder::with_avgdl(self.embedder.avgdl())
.language_mode(language)
.build();
self
}
pub fn avgdl(&self) -> f32 {
self.embedder.avgdl()
}
pub fn is_fitted(&self) -> bool {
self.fitted
}
pub fn upsert(&mut self, document: &FieldDocument<K>) {
let id = &document.id;
let title_emb = self.embedder.embed(&document.title);
let summary_emb = self.embedder.embed(&document.summary);
let content_emb = self.embedder.embed(&document.content);
self.title_scorer.upsert(id, title_emb);
self.summary_scorer.upsert(id, summary_emb);
self.content_scorer.upsert(id, content_emb);
let combined = self.embedder.embed(&document.combined_text());
self.scorer.upsert(id, combined);
self.doc_count += 1;
}
pub fn remove(&mut self, id: &K) {
self.scorer.remove(id);
self.title_scorer.remove(id);
self.summary_scorer.remove(id);
self.content_scorer.remove(id);
self.doc_count = self.doc_count.saturating_sub(1);
}
pub fn len(&self) -> usize {
self.doc_count
}
pub fn is_empty(&self) -> bool {
self.doc_count == 0
}
pub fn score(&self, id: &K, query: &str) -> Option<f32> {
let query_emb = self.embedder.embed(query);
let title_score = self.title_scorer.score(id, &query_emb)?;
let summary_score = self.summary_scorer.score(id, &query_emb)?;
let content_score = self.content_scorer.score(id, &query_emb)?;
let total_weight = self.weights.title + self.weights.summary + self.weights.content;
let weighted_score = (title_score * self.weights.title
+ summary_score * self.weights.summary
+ content_score * self.weights.content)
/ total_weight;
Some(weighted_score)
}
pub fn search(&self, query: &str, limit: usize) -> Vec<ScoredDocument<K>> {
let query_emb = self.embedder.embed(query);
self.scorer
.matches(&query_emb)
.into_iter()
.take(limit)
.collect()
}
pub fn search_weighted(&self, query: &str, limit: usize) -> Vec<(K, f32)> {
let query_emb = self.embedder.embed(query);
let all_results = self.scorer.matches(&query_emb);
let mut scored: Vec<(K, f32)> = all_results
.into_iter()
.filter_map(|scored_doc| {
let id = scored_doc.id;
let title_score = self.title_scorer.score(&id, &query_emb)?;
let summary_score = self.summary_scorer.score(&id, &query_emb)?;
let content_score = self.content_scorer.score(&id, &query_emb)?;
let total_weight = self.weights.title + self.weights.summary + self.weights.content;
let weighted_score = (title_score * self.weights.title
+ summary_score * self.weights.summary
+ content_score * self.weights.content)
/ total_weight;
Some((id, weighted_score))
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(limit);
scored
}
pub fn tokenize(&self, text: &str) -> Vec<String> {
let tokenizer = DefaultTokenizer::builder()
.language_mode(Language::English)
.normalization(true)
.stopwords(true)
.stemming(true)
.build();
tokenizer.tokenize(text)
}
pub fn embedder(&self) -> &Embedder {
&self.embedder
}
pub fn embedder_mut(&mut self) -> &mut Embedder {
&mut self.embedder
}
}
impl<K: std::hash::Hash + Eq + Clone + std::fmt::Debug> Default for Bm25Engine<K> {
fn default() -> Self {
Self::new()
}
}
/// A search query together with the extra terms produced by query expansion.
#[derive(Clone, Debug)]
pub struct ExpandedQuery {
    /// The unexpanded query text.
    pub original: String,
    /// Additional terms added to the original query.
    pub expansions: Vec<String>,
    /// `original` and `expansions` joined into one space-separated string.
    pub combined: String,
}
impl ExpandedQuery {
    /// Builds an expanded query from `original` plus its `expansions`.
    ///
    /// `combined` is `original` followed by each expansion, separated by single spaces. Empty
    /// pieces are skipped, so a query with no expansions yields exactly `original` rather than
    /// `original` plus a trailing space.
    pub fn new(original: String, expansions: Vec<String>) -> Self {
        let combined = std::iter::once(original.as_str())
            .chain(expansions.iter().map(String::as_str))
            .filter(|part| !part.is_empty())
            .collect::<Vec<_>>()
            .join(" ");
        Self {
            original,
            expansions,
            combined,
        }
    }
}
/// Asynchronously turns a query into an [`ExpandedQuery`] (the original text plus extra terms).
///
/// The `Send + Sync` bound means implementors can be shared across threads.
#[async_trait::async_trait]
pub trait QueryExpander: Send + Sync {
/// Expands `query` into an [`ExpandedQuery`].
async fn expand(&self, query: &str) -> ExpandedQuery;
}
/// Lowercase English stopwords that [`extract_keywords`] filters out of queries.
///
/// `extract_keywords` lowercases the query before checking membership here, so every entry
/// must be lowercase to match. The single-letter entries (`"a"`, `"i"`) never affect that
/// function's result, because it already drops one-letter ASCII tokens.
pub const STOPWORDS: &[&str] = &[
"a",
"an",
"the",
"is",
"are",
"was",
"were",
"be",
"been",
"being",
"have",
"has",
"had",
"do",
"does",
"did",
"will",
"would",
"could",
"should",
"may",
"might",
"must",
"shall",
"can",
"need",
"dare",
"ought",
"used",
"to",
"of",
"in",
"for",
"on",
"with",
"at",
"by",
"from",
"as",
"into",
"through",
"during",
"before",
"after",
"above",
"below",
"between",
"under",
"again",
"further",
"then",
"once",
"here",
"there",
"when",
"where",
"why",
"how",
"all",
"each",
"few",
"more",
"most",
"other",
"some",
"such",
"no",
"nor",
"not",
"only",
"own",
"same",
"so",
"than",
"too",
"very",
"just",
"and",
"but",
"if",
"or",
"because",
"until",
"while",
"about",
"what",
"which",
"who",
"whom",
"this",
"that",
"these",
"those",
"i",
"me",
"my",
"myself",
"we",
"our",
"ours",
"ourselves",
"you",
"your",
"yours",
"yourself",
"yourselves",
"he",
"him",
"his",
"himself",
"she",
"her",
"hers",
"herself",
"it",
"its",
"itself",
"they",
"them",
"their",
"theirs",
"themselves",
];
/// Splits `query` into lowercase keyword tokens, dropping stopwords and one-character tokens.
///
/// The query is lowercased and split on every non-alphanumeric character. Both steps are
/// Unicode-aware. A piece is kept only if it has at least two characters and is not in
/// [`STOPWORDS`]. Length is counted in `char`s rather than bytes, so a lone accented letter
/// such as `"é"` is dropped just like `"e"`.
#[must_use]
pub fn extract_keywords(query: &str) -> Vec<String> {
    query
        .to_lowercase()
        .split(|c: char| !c.is_alphanumeric())
        // `chars().nth(1)` is `Some` iff the token has at least two characters; it also rejects
        // the empty pieces produced by adjacent separators.
        .filter(|token| token.chars().nth(1).is_some() && !STOPWORDS.contains(token))
        .map(String::from)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bm25_engine_creation() {
        let engine: Bm25Engine<u32> = Bm25Engine::new();
        assert!(engine.is_empty());
        assert!(!engine.is_fitted());
    }

    #[test]
    fn test_bm25_engine_fit_to_corpus() {
        let docs = vec![
            FieldDocument::new(
                1u32,
                "Rust Programming".to_string(),
                "About Rust".to_string(),
                "Rust is a systems programming language.".to_string(),
            ),
            FieldDocument::new(
                2u32,
                "Python Guide".to_string(),
                "About Python".to_string(),
                "Python is a scripting language.".to_string(),
            ),
        ];
        let engine = Bm25Engine::fit_to_corpus(&docs);
        assert!(engine.is_fitted());
        assert_eq!(engine.len(), 2);
    }

    #[test]
    fn test_bm25_search() {
        let docs = vec![
            FieldDocument::new(
                1u32,
                "Rust Programming".to_string(),
                "About Rust".to_string(),
                "Rust is a systems programming language with memory safety.".to_string(),
            ),
            FieldDocument::new(
                2u32,
                "Python Guide".to_string(),
                "About Python".to_string(),
                "Python is a scripting language for data science.".to_string(),
            ),
            FieldDocument::new(
                3u32,
                "Rust Memory Safety".to_string(),
                "Memory in Rust".to_string(),
                "Rust provides guaranteed memory safety without garbage collection.".to_string(),
            ),
        ];
        let engine = Bm25Engine::fit_to_corpus(&docs);
        let results = engine.search("rust memory", 10);
        assert!(!results.is_empty());
        assert!(results.iter().any(|r| r.id == 1 || r.id == 3));
    }

    #[test]
    fn test_bm25_search_respects_limit() {
        let docs = vec![
            FieldDocument::new(
                1u32,
                "Rust Programming".to_string(),
                "About Rust".to_string(),
                "Rust is a systems programming language.".to_string(),
            ),
            FieldDocument::new(
                2u32,
                "Python Guide".to_string(),
                "About Python".to_string(),
                "Python is a scripting language.".to_string(),
            ),
            FieldDocument::new(
                3u32,
                "Rust Memory Safety".to_string(),
                "Memory in Rust".to_string(),
                "Rust provides memory safety.".to_string(),
            ),
        ];
        let engine = Bm25Engine::fit_to_corpus(&docs);
        assert_eq!(engine.search("rust", 1).len(), 1);
        assert_eq!(engine.search_weighted("rust", 1).len(), 1);
    }

    #[test]
    fn test_bm25_weighted_search() {
        let docs = vec![
            FieldDocument::new(
                1u32,
                "Rust Programming".to_string(),
                "About memory safety".to_string(),
                "Content about other things.".to_string(),
            ),
            FieldDocument::new(
                2u32,
                "Other Language".to_string(),
                "About other things".to_string(),
                "Rust memory safety is important.".to_string(),
            ),
        ];
        let engine = Bm25Engine::fit_to_corpus(&docs).with_weights(FieldWeights {
            title: 3.0,
            summary: 2.0,
            content: 1.0,
        });
        let results = engine.search_weighted("rust", 10);
        assert_eq!(results.first().map(|(id, _)| *id), Some(1u32));
    }

    #[test]
    fn test_bm25_score() {
        let docs = vec![FieldDocument::new(
            1u32,
            "Rust Programming".to_string(),
            "About Rust".to_string(),
            "Rust is a systems programming language.".to_string(),
        )];
        let engine = Bm25Engine::fit_to_corpus(&docs);
        let score = engine.score(&1u32, "rust programming");
        assert!(score.is_some());
        assert!(score.unwrap() > 0.0);
    }

    #[test]
    fn test_bm25_score_unknown_id() {
        let docs = vec![FieldDocument::new(
            1u32,
            "Rust".to_string(),
            "About Rust".to_string(),
            "Rust content.".to_string(),
        )];
        let engine = Bm25Engine::fit_to_corpus(&docs);
        assert!(engine.score(&42u32, "rust").is_none());
    }

    #[test]
    fn test_bm25_tokenize() {
        let engine: Bm25Engine<u32> = Bm25Engine::new();
        let tokens = engine.tokenize("What is the Rust programming language?");
        assert!(tokens.contains(&"rust".to_string()));
        // "programming" is stemmed; "what" and "the" are stopwords.
        assert!(tokens.contains(&"program".to_string()));
        assert!(!tokens.contains(&"what".to_string()));
        assert!(!tokens.contains(&"the".to_string()));
    }

    #[test]
    fn test_bm25_remove() {
        let docs = vec![FieldDocument::new(
            1u32,
            "Rust".to_string(),
            "About Rust".to_string(),
            "Rust content.".to_string(),
        )];
        let mut engine = Bm25Engine::fit_to_corpus(&docs);
        assert_eq!(engine.len(), 1);
        engine.remove(&1u32);
        assert!(engine.is_empty());
    }

    #[test]
    fn test_field_document_combined_text() {
        let doc = FieldDocument::new(
            1u32,
            "Title".to_string(),
            "Summary".to_string(),
            "Body".to_string(),
        );
        assert_eq!(doc.combined_text(), "Title Summary Body");
    }

    #[test]
    fn test_field_weights_default() {
        let weights = FieldWeights::default();
        assert!((weights.title - 2.0).abs() < f32::EPSILON);
        assert!((weights.summary - 1.5).abs() < f32::EPSILON);
        assert!((weights.content - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_bm25_params_default() {
        let params = Bm25Params::default();
        assert!((params.k1 - 1.2).abs() < f32::EPSILON);
        assert!((params.b - 0.75).abs() < f32::EPSILON);
        assert!((params.avgdl - 100.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_expanded_query() {
        let expanded = ExpandedQuery::new(
            "rust".to_string(),
            vec!["programming".to_string(), "language".to_string()],
        );
        assert_eq!(expanded.original, "rust");
        assert_eq!(expanded.expansions.len(), 2);
        assert_eq!(expanded.combined, "rust programming language");
    }

    #[test]
    fn test_extract_keywords() {
        let keywords = extract_keywords("What is the Rust programming language?");
        assert_eq!(keywords, vec!["rust", "programming", "language"]);
    }
}