use std::collections::{HashMap, HashSet};
use rust_stemmers::{Algorithm, Stemmer};
use serde::{Deserialize, Serialize};
/// Generates the `StemLanguage` enum, its `Default` impl, its conversion
/// helpers and the `parse_language` function from one table of
/// `Variant => "canonical name", StemmerAlgorithm;` rows.
///
/// Keeping everything in one table guarantees that the enum, the canonical
/// names and the stemmer algorithm mapping can never drift apart.
macro_rules! define_stem_languages {
    (
        default = $default:ident;
        $( $variant:ident => $canonical:literal, $algo:ident; )+
    ) => {
        /// A language for which a stemming algorithm is available.
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        pub enum StemLanguage {
            $( $variant, )+
        }

        impl Default for StemLanguage {
            fn default() -> Self {
                Self::$default
            }
        }

        impl StemLanguage {
            /// Maps the language to the matching stemmer algorithm.
            pub(crate) fn to_algorithm(self) -> Algorithm {
                match self {
                    $( Self::$variant => Algorithm::$algo, )+
                }
            }

            /// Returns the lowercase full name accepted by `parse_language`.
            pub fn canonical_name(self) -> &'static str {
                match self {
                    $( Self::$variant => $canonical, )+
                }
            }
        }

        /// Parses a language given either as a canonical full name or as a
        /// code known to `iso639_to_canonical`; matching is case-insensitive.
        ///
        /// # Errors
        ///
        /// Returns an error when `s` is neither a known name nor a known code.
        pub fn parse_language(s: &str) -> anyhow::Result<StemLanguage> {
            let lower = s.to_lowercase();
            match lower.as_str() {
                // Canonical full names take priority over short codes.
                $( $canonical => Ok(StemLanguage::$variant), )+
                other => iso639_to_canonical(other).ok_or_else(|| {
                    anyhow::anyhow!(
                        "unknown stemming language: {s:?} (use a full name like \"english\" or an ISO 639-1 code like \"en\")"
                    )
                }),
            }
        }
    };
}
/// Resolves a lowercase two-letter language code to its [`StemLanguage`].
///
/// The lookup is exact: callers are expected to lowercase the code first.
/// Returns `None` for any code that is not listed below.
fn iso639_to_canonical(code: &str) -> Option<StemLanguage> {
    let language = match code {
        "ar" => StemLanguage::Arabic,
        "da" => StemLanguage::Danish,
        "nl" => StemLanguage::Dutch,
        "en" => StemLanguage::English,
        "fi" => StemLanguage::Finnish,
        "fr" => StemLanguage::French,
        "de" => StemLanguage::German,
        "el" => StemLanguage::Greek,
        "hu" => StemLanguage::Hungarian,
        "it" => StemLanguage::Italian,
        // All three Norwegian codes share the single Norwegian stemmer.
        "no" | "nb" | "nn" => StemLanguage::Norwegian,
        "pt" => StemLanguage::Portuguese,
        "ro" => StemLanguage::Romanian,
        "ru" => StemLanguage::Russian,
        "es" => StemLanguage::Spanish,
        "sv" => StemLanguage::Swedish,
        "ta" => StemLanguage::Tamil,
        "tr" => StemLanguage::Turkish,
        _ => return None,
    };
    Some(language)
}
// The single source of truth for supported stemming languages. Each row is
// `EnumVariant => "canonical lowercase name", StemmerAlgorithmVariant;`.
// The macro expands this table into the `StemLanguage` enum, its `Default`
// impl (English), `to_algorithm`, `canonical_name` and `parse_language`.
// When adding a row, also add its short code to `iso639_to_canonical`.
define_stem_languages! {
default = English;
Arabic => "arabic", Arabic;
Danish => "danish", Danish;
Dutch => "dutch", Dutch;
English => "english", English;
Finnish => "finnish", Finnish;
French => "french", French;
German => "german", German;
Greek => "greek", Greek;
Hungarian => "hungarian", Hungarian;
Italian => "italian", Italian;
Norwegian => "norwegian", Norwegian;
Portuguese => "portuguese", Portuguese;
Romanian => "romanian", Romanian;
Russian => "russian", Russian;
Spanish => "spanish", Spanish;
Swedish => "swedish", Swedish;
Tamil => "tamil", Tamil;
Turkish => "turkish", Turkish;
}
/// Picks the stemming language from up to three optional sources.
///
/// Sources are consulted in priority order: front matter, then the command
/// line, then the configuration. A source that is absent or that fails to
/// parse is skipped. When no source yields a language the default
/// language is returned.
pub fn resolve_language(
    frontmatter_lang: Option<&str>,
    cli_lang: Option<&str>,
    config_lang: Option<&str>,
) -> StemLanguage {
    let candidates = [frontmatter_lang, cli_lang, config_lang];
    candidates
        .into_iter()
        .flatten()
        // Unparseable values are ignored rather than reported.
        .find_map(|candidate| parse_language(candidate).ok())
        .unwrap_or_default()
}
/// Builds a stemmer for the given language.
pub fn create_stemmer(lang: StemLanguage) -> Stemmer {
    let algorithm = lang.to_algorithm();
    Stemmer::create(algorithm)
}
/// Splits `text` into lowercase, stemmed tokens.
///
/// Any character that is not alphanumeric acts as a separator, so
/// punctuation, whitespace and underscores all split words. Empty fragments
/// between consecutive separators are dropped. Each remaining word is
/// lowercased and then passed through `stemmer`.
pub fn tokenize(text: &str, stemmer: &Stemmer) -> Vec<String> {
    let mut tokens = Vec::new();
    for word in text.split(|c: char| !c.is_alphanumeric()) {
        if word.is_empty() {
            continue;
        }
        tokens.push(stemmer.stem(&word.to_lowercase()).into_owned());
    }
    tokens
}
/// A document that has already been tokenized and is ready to be indexed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreTokenizedInput {
    /// Path that identifies the document in search results.
    pub rel_path: String,
    /// The document's tokens in their original order; a token's index in
    /// this vector becomes its position in the index.
    pub tokens: Vec<String>,
}
pub struct DocumentInput {
pub rel_path: String,
pub title: String,
pub body: String,
pub language: StemLanguage,
}
pub fn tokenize_document(doc: DocumentInput) -> PreTokenizedInput {
let stemmer = Stemmer::create(doc.language.to_algorithm());
let combined = format!("{} {}", doc.title, doc.body);
let tokens = tokenize(&combined, &stemmer);
PreTokenizedInput {
rel_path: doc.rel_path,
tokens,
}
}
/// One search hit: a document path together with its BM25 score.
///
/// `Eq` is intentionally not derived because `score` is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct Bm25Match {
    /// Path of the matching document.
    pub rel_path: String,
    /// Relevance score; higher means more relevant.
    pub score: f64,
}
/// One clause of a parsed boolean query. Every variant carries the stemmed
/// tokens produced by `tokenize`; a clause with more than one token is
/// treated as a phrase by the scorer.
#[derive(Debug, Clone, PartialEq)]
enum Clause {
// Required: a document must satisfy every `Must` clause to be returned.
Must(Vec<String>),
// Optional alternative: used for all positive clauses when the query contains OR.
Should(Vec<String>),
// Exclusion: documents containing any of these tokens are dropped.
MustNot(Vec<String>),
}
/// A parsed query: the list of clauses in the order they appeared in the
/// query string.
#[derive(Debug, Clone)]
struct BooleanQuery {
clauses: Vec<Clause>,
}
impl BooleanQuery {
    /// Returns `true` when the query has no clauses at all.
    fn is_empty(&self) -> bool {
        self.clauses.is_empty()
    }

    /// Returns `true` when at least one clause is `Must` or `Should`,
    /// i.e. when the query is not made up solely of exclusions.
    fn has_positive_clauses(&self) -> bool {
        !self
            .clauses
            .iter()
            .all(|clause| matches!(clause, Clause::MustNot(_)))
    }

    /// Collects the tokens of every `MustNot` clause, in clause order.
    fn must_not_terms(&self) -> Vec<&str> {
        let mut excluded = Vec::new();
        for clause in &self.clauses {
            if let Clause::MustNot(terms) = clause {
                excluded.extend(terms.iter().map(String::as_str));
            }
        }
        excluded
    }
}
/// A lexical piece of a raw query string, before stemming.
#[derive(Debug)]
enum QuerySegment {
// A bare word (anything up to whitespace or a double quote).
Term(String),
// The text between a pair of double quotes.
Phrase(String),
// A word or quoted phrase that was prefixed with `-`.
Negated(String),
// The bare keyword `or`, matched case-insensitively.
Or,
// The bare keyword `and`, matched case-insensitively.
And,
}
/// Splits a raw query string into lexical segments.
///
/// Recognised forms:
/// * `"some words"` becomes a `Phrase` (an unterminated quote runs to the
///   end of the input);
/// * `-word` and `-"some words"` become `Negated`;
/// * the bare words `or` / `and` (any ASCII case) become `Or` / `And`;
/// * any other run of characters up to whitespace or `"` becomes a `Term`.
///
/// Empty phrases and a lone `-` produce no segment.
fn tokenize_query_segments(query: &str) -> Vec<QuerySegment> {
    type Cursor<'a> = std::iter::Peekable<std::str::Chars<'a>>;

    /// Reads up to the closing `"` (which is consumed) or to end of input.
    /// The opening quote must already have been consumed.
    fn take_quoted(cursor: &mut Cursor<'_>) -> String {
        cursor.by_ref().take_while(|&c| c != '"').collect()
    }

    /// Reads characters until whitespace or `"`, leaving the delimiter in place.
    fn take_word(cursor: &mut Cursor<'_>) -> String {
        let mut word = String::new();
        while let Some(c) = cursor.next_if(|c| !c.is_whitespace() && *c != '"') {
            word.push(c);
        }
        word
    }

    let mut cursor: Cursor<'_> = query.chars().peekable();
    let mut out = Vec::new();

    while let Some(&lead) = cursor.peek() {
        match lead {
            c if c.is_whitespace() => {
                cursor.next();
            }
            '"' => {
                cursor.next();
                let phrase = take_quoted(&mut cursor);
                if !phrase.is_empty() {
                    out.push(QuerySegment::Phrase(phrase));
                }
            }
            '-' => {
                cursor.next();
                // A quote right after the dash negates a whole phrase.
                let negated = if cursor.next_if_eq(&'"').is_some() {
                    take_quoted(&mut cursor)
                } else {
                    take_word(&mut cursor)
                };
                if !negated.is_empty() {
                    out.push(QuerySegment::Negated(negated));
                }
            }
            _ => {
                let word = take_word(&mut cursor);
                if word.eq_ignore_ascii_case("or") {
                    out.push(QuerySegment::Or);
                } else if word.eq_ignore_ascii_case("and") {
                    out.push(QuerySegment::And);
                } else if !word.is_empty() {
                    out.push(QuerySegment::Term(word));
                }
            }
        }
    }

    out
}
/// Parses a raw query string into a [`BooleanQuery`] of stemmed clauses.
///
/// If the query contains an `OR` keyword anywhere, every positive segment
/// becomes a `Should` clause; otherwise every positive segment becomes a
/// `Must` clause. `AND`/`OR` keywords themselves produce no clause.
/// A phrase yields one multi-token clause, while a bare term yields one
/// single-token clause per token it stems to. Segments that produce no
/// tokens are dropped.
fn parse_boolean_query(query: &str, stemmer: &Stemmer) -> BooleanQuery {
    let segments = tokenize_query_segments(query);
    let any_or = segments
        .iter()
        .any(|segment| matches!(segment, QuerySegment::Or));
    // Wraps tokens in the positive clause kind this query uses.
    let positive = |tokens: Vec<String>| {
        if any_or {
            Clause::Should(tokens)
        } else {
            Clause::Must(tokens)
        }
    };

    let mut clauses = Vec::new();
    for segment in segments {
        match segment {
            QuerySegment::Or | QuerySegment::And => {}
            QuerySegment::Negated(text) => {
                let stems = tokenize(&text, stemmer);
                if !stems.is_empty() {
                    clauses.push(Clause::MustNot(stems));
                }
            }
            QuerySegment::Phrase(text) => {
                let stems = tokenize(&text, stemmer);
                if !stems.is_empty() {
                    clauses.push(positive(stems));
                }
            }
            QuerySegment::Term(text) => {
                // A term such as `foo-bar` can split into several tokens;
                // each one becomes its own clause.
                clauses.extend(
                    tokenize(&text, stemmer)
                        .into_iter()
                        .map(|stem| positive(vec![stem])),
                );
            }
        }
    }
    BooleanQuery { clauses }
}
/// The occurrences of one term inside one document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Posting {
// Index of the document in the owning index's `doc_paths` / `doc_lengths`.
pub(crate) doc_id: u32,
// Number of times the term occurs in the document.
pub(crate) term_freq: u32,
// Zero-based token positions of each occurrence, in ascending order when
// produced by `Bm25InvertedIndex::build_from_tokens`.
pub(crate) positions: Vec<u32>,
}
/// A serializable positional inverted index used for BM25 ranking.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Bm25InvertedIndex {
// Term -> one posting per document that contains the term.
postings: HashMap<String, Vec<Posting>>,
// Token count of each document, indexed by document id.
doc_lengths: Vec<u32>,
// Path of each document, indexed by document id.
doc_paths: Vec<String>,
// Average document length in tokens (256.0 for an index built from no documents).
avgdl: f64,
}
impl Bm25InvertedIndex {
    /// BM25 `k1` parameter used in the term-frequency normalisation.
    const K1: f64 = 1.2;
    /// BM25 `b` parameter weighting the document-length normalisation.
    const B: f64 = 0.75;

    /// Tokenizes every document and builds an index from the result.
    pub fn build(docs: Vec<DocumentInput>) -> Self {
        let pre_tokenized: Vec<PreTokenizedInput> =
            docs.into_iter().map(tokenize_document).collect();
        Self::build_from_tokens(pre_tokenized)
    }

    /// Builds an index from documents that are already tokenized.
    ///
    /// Document ids are assigned in input order, so every posting list ends
    /// up sorted by ascending document id and every position list by
    /// ascending position; the phrase matcher relies on both orderings.
    pub fn build_from_tokens(docs: Vec<PreTokenizedInput>) -> Self {
        let n = docs.len();
        let mut postings: HashMap<String, Vec<Posting>> = HashMap::new();
        let mut doc_lengths: Vec<u32> = Vec::with_capacity(n);
        let mut doc_paths: Vec<String> = Vec::with_capacity(n);

        for (doc_id, doc) in docs.into_iter().enumerate() {
            // Ids, positions and lengths are stored as u32 to keep the
            // index compact; values beyond u32::MAX would be truncated.
            #[allow(clippy::cast_possible_truncation)]
            let doc_id = doc_id as u32;
            let token_count = doc.tokens.len();

            // Per-document term statistics: (frequency, positions).
            let mut tf: HashMap<&str, (u32, Vec<u32>)> = HashMap::new();
            for (pos, token) in doc.tokens.iter().enumerate() {
                let entry = tf.entry(token.as_str()).or_insert_with(|| (0, Vec::new()));
                entry.0 += 1;
                #[allow(clippy::cast_possible_truncation)]
                entry.1.push(pos as u32);
            }
            for (term, (freq, positions)) in tf {
                postings.entry(term.to_owned()).or_default().push(Posting {
                    doc_id,
                    term_freq: freq,
                    positions,
                });
            }

            #[allow(clippy::cast_possible_truncation)]
            doc_lengths.push(token_count as u32);
            doc_paths.push(doc.rel_path);
        }

        #[allow(clippy::cast_precision_loss)]
        let avgdl: f64 = if n == 0 {
            // Placeholder for an empty index; never used in a division
            // because an empty index has no postings to score.
            256.0
        } else {
            let total: u64 = doc_lengths.iter().map(|&l| u64::from(l)).sum();
            total as f64 / n as f64
        };

        Self {
            postings,
            doc_lengths,
            doc_paths,
            avgdl,
        }
    }

    /// Builds an index from the entries that carry stored BM25 tokens.
    ///
    /// Entries without tokens are skipped. Returns `None` when no entry has
    /// tokens at all.
    pub fn build_from_entries(entries: &[crate::index::IndexEntry]) -> Option<Self> {
        let docs: Vec<PreTokenizedInput> = entries
            .iter()
            .filter_map(|e| {
                e.bm25_tokens.as_ref().map(|tokens| PreTokenizedInput {
                    rel_path: e.rel_path.clone(),
                    tokens: tokens.clone(),
                })
            })
            .collect();
        if docs.is_empty() {
            return None;
        }
        Some(Self::build_from_tokens(docs))
    }

    /// Size measure of the index: one unit per posting plus one per stored
    /// position.
    pub(crate) fn total_postings(&self) -> usize {
        self.postings
            .values()
            .flat_map(|posts| posts.iter())
            .map(|p| 1 + p.positions.len())
            .sum()
    }

    /// Runs `query` against the index and returns the matching documents,
    /// best score first. Documents with equal scores are ordered by path so
    /// the result order is deterministic.
    ///
    /// `stemmer` should be the same kind of stemmer that produced the
    /// indexed tokens, otherwise query terms will not line up with them.
    pub fn score(&self, query: &str, stemmer: &Stemmer) -> Vec<Bm25Match> {
        self.ranked_matches(query, stemmer)
    }

    /// Number of documents in the index.
    pub fn doc_count(&self) -> usize {
        self.doc_paths.len()
    }

    /// Checks that the index is internally consistent: there is one length
    /// per path, and every posting refers to an existing document.
    ///
    /// Scoring indexes `doc_lengths` and `doc_paths` directly and panics on
    /// an out-of-range id, so an index obtained by deserialization should
    /// pass this check before it is queried.
    pub(crate) fn validate_doc_ids(&self) -> bool {
        let max_id = self.doc_paths.len();
        if self.doc_lengths.len() != max_id {
            return false;
        }
        self.postings
            .values()
            .all(|posts| posts.iter().all(|p| (p.doc_id as usize) < max_id))
    }

    /// Assembles an index directly from its parts, without any validation.
    #[cfg(test)]
    pub(crate) fn new_for_test(
        postings: HashMap<String, Vec<Posting>>,
        doc_lengths: Vec<u32>,
        doc_paths: Vec<String>,
        avgdl: f64,
    ) -> Self {
        Self {
            postings,
            doc_lengths,
            doc_paths,
            avgdl,
        }
    }

    /// Parses `query`, scores every candidate document and returns the
    /// survivors sorted by descending score (ties broken by path).
    ///
    /// Semantics:
    /// * a document containing any `MustNot` token is never returned;
    /// * a single-token clause matches documents containing the token;
    /// * a multi-token clause is a phrase and matches only documents in
    ///   which the tokens occur at consecutive positions;
    /// * a clause contributes score only to the documents it matches;
    /// * when the query has `Must` clauses, a document must match all of
    ///   them; otherwise matching any `Should` clause is enough.
    fn ranked_matches(&self, query: &str, stemmer: &Stemmer) -> Vec<Bm25Match> {
        let query = parse_boolean_query(query, stemmer);
        if query.is_empty() || !query.has_positive_clauses() {
            return Vec::new();
        }

        #[allow(clippy::cast_precision_loss)]
        let n = self.doc_paths.len() as f64;
        let avgdl = self.avgdl;

        let excluded: HashSet<u32> = query
            .must_not_terms()
            .into_iter()
            .filter_map(|t| self.postings.get(t))
            .flat_map(|posts| posts.iter().map(|p| p.doc_id))
            .collect();

        let must_clause_count = query
            .clauses
            .iter()
            .filter(|c| matches!(c, Clause::Must(_)))
            .count();

        let mut scores: HashMap<u32, f64> = HashMap::new();
        let mut must_hits: HashMap<u32, usize> = HashMap::new();

        for clause in &query.clauses {
            let (terms, is_must) = match clause {
                Clause::Must(t) => (t, true),
                Clause::Should(t) => (t, false),
                Clause::MustNot(_) => continue,
            };

            // For a phrase, the documents where the tokens are adjacent;
            // `None` for a single-token clause.
            let phrase_docs: Option<HashSet<u32>> = (terms.len() > 1).then(|| {
                self.docs_with_all_terms(terms)
                    .into_iter()
                    .filter(|&doc_id| self.has_phrase_at_positions(terms, doc_id))
                    .collect()
            });

            for term in terms {
                let Some(posting_list) = self.postings.get(term) else {
                    continue;
                };
                #[allow(clippy::cast_precision_loss)]
                let nt = posting_list.len() as f64;
                let idf = (1.0 + (n - nt + 0.5) / (nt + 0.5)).ln();
                for p in posting_list {
                    // A phrase clause only rewards documents that contain
                    // the whole phrase, never a stray word from it.
                    let clause_matches = phrase_docs
                        .as_ref()
                        .map_or(true, |docs| docs.contains(&p.doc_id));
                    if !clause_matches || excluded.contains(&p.doc_id) {
                        continue;
                    }
                    let tf = f64::from(p.term_freq);
                    let dl = f64::from(self.doc_lengths[p.doc_id as usize]);
                    let tf_norm = (tf * (Self::K1 + 1.0))
                        / (tf + Self::K1 * (1.0 - Self::B + Self::B * dl / avgdl));
                    *scores.entry(p.doc_id).or_insert(0.0) += idf * tf_norm;
                }
            }

            if is_must {
                // Count, per document, how many required clauses it satisfies.
                match &phrase_docs {
                    Some(docs) => {
                        for &doc_id in docs {
                            *must_hits.entry(doc_id).or_insert(0) += 1;
                        }
                    }
                    None => {
                        let posting_list = terms.first().and_then(|t| self.postings.get(t));
                        for p in posting_list.into_iter().flatten() {
                            *must_hits.entry(p.doc_id).or_insert(0) += 1;
                        }
                    }
                }
            }
        }

        let mut matches: Vec<Bm25Match> = scores
            .into_iter()
            .filter(|(doc_id, score)| {
                // `> 0.0` also discards NaN scores.
                *score > 0.0
                    && (must_clause_count == 0
                        || must_hits.get(doc_id).copied().unwrap_or(0) >= must_clause_count)
            })
            .map(|(doc_id, score)| Bm25Match {
                rel_path: self.doc_paths[doc_id as usize].clone(),
                score,
            })
            .collect();

        // Scores come out of a HashMap, so ties need an explicit secondary
        // key to keep the output order stable between runs.
        matches.sort_unstable_by(|a, b| {
            b.score
                .total_cmp(&a.score)
                .then_with(|| a.rel_path.cmp(&b.rel_path))
        });
        matches
    }

    /// Returns the ids of the documents that contain every one of `terms`.
    ///
    /// The result is empty when `terms` is empty or when any term is absent
    /// from the index.
    fn docs_with_all_terms(&self, terms: &[String]) -> HashSet<u32> {
        let Some(lists) = terms
            .iter()
            .map(|t| self.postings.get(t))
            .collect::<Option<Vec<_>>>()
        else {
            return HashSet::new();
        };
        let mut lists = lists.into_iter();
        let Some(first) = lists.next() else {
            return HashSet::new();
        };

        let mut common: HashSet<u32> = first.iter().map(|p| p.doc_id).collect();
        for list in lists {
            if common.is_empty() {
                break;
            }
            let ids: HashSet<u32> = list.iter().map(|p| p.doc_id).collect();
            common.retain(|id| ids.contains(id));
        }
        common
    }

    /// Returns `true` when `terms` occur at consecutive token positions
    /// somewhere in document `doc_id`.
    ///
    /// Relies on posting lists being sorted by document id and position
    /// lists being sorted ascending, as `build_from_tokens` produces them.
    fn has_phrase_at_positions(&self, terms: &[String], doc_id: u32) -> bool {
        let mut position_lists: Vec<&[u32]> = Vec::with_capacity(terms.len());
        for term in terms {
            let positions = self.postings.get(term).and_then(|ps| {
                ps.binary_search_by_key(&doc_id, |p| p.doc_id)
                    .ok()
                    .map(|idx| ps[idx].positions.as_slice())
            });
            match positions {
                Some(list) if !list.is_empty() => position_lists.push(list),
                // A term missing from the document rules the phrase out.
                _ => return false,
            }
        }

        let Some((first, rest)) = position_lists.split_first() else {
            return false;
        };
        first.iter().any(|&start_pos| {
            rest.iter().enumerate().all(|(i, pos_list)| {
                // The (i + 1)-th following term must sit exactly i + 1
                // positions after the start; overflow means "no match".
                u32::try_from(i + 1)
                    .ok()
                    .and_then(|offset| start_pos.checked_add(offset))
                    .map_or(false, |target| pos_list.binary_search(&target).is_ok())
            })
        })
    }
}
/// Returns `true` when the query consists of nothing but `and` / `or`
/// keywords. An empty or whitespace-only query returns `false`.
pub fn query_is_operator_only(query: &str) -> bool {
    let segments = tokenize_query_segments(query);
    let only_operators = segments
        .iter()
        .all(|segment| matches!(segment, QuerySegment::Or | QuerySegment::And));
    // `all` is vacuously true for no segments, hence the explicit check.
    !segments.is_empty() && only_operators
}
/// Heuristic for queries that barely discriminate between documents.
///
/// Returns `true` when more than 80% of the `total_docs` documents matched
/// and the best score is below 1.0. Returns `false` when there are no
/// matches or no documents.
pub fn is_low_discriminative(matches: &[Bm25Match], total_docs: usize) -> bool {
    if matches.is_empty() || total_docs == 0 {
        return false;
    }
    debug_assert!(
        matches.len() <= total_docs,
        "matches ({}) exceeds total_docs ({})",
        matches.len(),
        total_docs
    );
    let top_score = matches
        .iter()
        .fold(0.0_f64, |best, hit| best.max(hit.score));
    #[allow(clippy::cast_precision_loss)]
    let coverage = matches.len() as f64 / total_docs as f64;
    top_score < 1.0 && coverage > 0.8
}
// Unit tests for tokenization, language resolution, query parsing, index
// construction, BM25 scoring and the query heuristics.
#[cfg(test)]
mod tests {
use super::*;
// Shorthand for building a stemmer for a given language.
fn make_stemmer(lang: StemLanguage) -> Stemmer {
Stemmer::create(lang.to_algorithm())
}
// ---- tokenize ----
#[test]
fn test_tokenize_english() {
let stemmer = make_stemmer(StemLanguage::English);
let tokens = tokenize("running quickly", &stemmer);
assert_eq!(tokens, vec!["run", "quick"]);
}
#[test]
fn test_tokenize_french() {
let stemmer = make_stemmer(StemLanguage::French);
let tokens = tokenize("mangeons rapidement", &stemmer);
assert_eq!(tokens.len(), 2);
assert!(tokens.iter().all(|t| !t.is_empty()));
}
// Both `-` and `_` are non-alphanumeric, so each splits a word in two.
#[test]
fn test_tokenize_splits_on_punctuation() {
let stemmer = make_stemmer(StemLanguage::English);
let tokens = tokenize("hello-world foo_bar", &stemmer);
assert_eq!(tokens.len(), 4);
}
#[test]
fn test_tokenize_unicode_lowercase() {
let stemmer = make_stemmer(StemLanguage::English);
let tokens = tokenize("ÜBER", &stemmer);
assert_eq!(tokens.len(), 1);
assert_eq!(tokens[0], tokens[0].to_lowercase());
}
#[test]
fn test_tokenize_per_token_lowercase() {
let stemmer = make_stemmer(StemLanguage::English);
let tokens = tokenize("RUST Programming", &stemmer);
assert_eq!(tokens.len(), 2);
for t in &tokens {
assert_eq!(t, &t.to_lowercase(), "token {t:?} should be lowercase");
}
}
#[test]
fn test_tokenize_document_function() {
let input = DocumentInput {
rel_path: "test.md".to_owned(),
title: "Running guide".to_owned(),
body: "A guide to running faster.".to_owned(),
language: StemLanguage::English,
};
let pre = tokenize_document(input);
assert_eq!(pre.rel_path, "test.md");
assert!(pre.tokens.contains(&"run".to_owned()));
}
// ---- language resolution and parsing ----
// Priority is front matter, then CLI, then config; unparseable values are skipped.
#[test]
fn test_resolve_language_precedence() {
assert_eq!(
resolve_language(Some("french"), Some("german"), Some("spanish")),
StemLanguage::French
);
assert_eq!(
resolve_language(None, Some("german"), Some("spanish")),
StemLanguage::German
);
assert_eq!(
resolve_language(None, None, Some("spanish")),
StemLanguage::Spanish
);
assert_eq!(resolve_language(None, None, None), StemLanguage::English);
assert_eq!(
resolve_language(Some("klingon"), Some("italian"), None),
StemLanguage::Italian
);
}
// Full names are accepted in any letter case.
#[test]
fn test_parse_language_valid() {
let cases = [
("arabic", StemLanguage::Arabic),
("Danish", StemLanguage::Danish),
("DUTCH", StemLanguage::Dutch),
("English", StemLanguage::English),
("finnish", StemLanguage::Finnish),
("French", StemLanguage::French),
("german", StemLanguage::German),
("greek", StemLanguage::Greek),
("Hungarian", StemLanguage::Hungarian),
("Italian", StemLanguage::Italian),
("norwegian", StemLanguage::Norwegian),
("portuguese", StemLanguage::Portuguese),
("romanian", StemLanguage::Romanian),
("russian", StemLanguage::Russian),
("spanish", StemLanguage::Spanish),
("Swedish", StemLanguage::Swedish),
("Tamil", StemLanguage::Tamil),
("Turkish", StemLanguage::Turkish),
];
for (input, expected) in cases {
assert_eq!(parse_language(input).expect(input), expected, "{input}");
}
}
#[test]
fn test_parse_language_invalid() {
assert!(parse_language("klingon").is_err());
assert!(parse_language("").is_err());
assert!(parse_language("xx").is_err());
}
#[test]
fn test_parse_language_iso639_codes() {
assert_eq!(parse_language("en").unwrap(), StemLanguage::English);
assert_eq!(parse_language("de").unwrap(), StemLanguage::German);
assert_eq!(parse_language("fr").unwrap(), StemLanguage::French);
assert_eq!(parse_language("es").unwrap(), StemLanguage::Spanish);
assert_eq!(parse_language("ar").unwrap(), StemLanguage::Arabic);
assert_eq!(parse_language("no").unwrap(), StemLanguage::Norwegian);
assert_eq!(parse_language("nb").unwrap(), StemLanguage::Norwegian);
assert_eq!(parse_language("EN").unwrap(), StemLanguage::English);
}
// ---- boolean query parsing ----
// Without an OR keyword, every positive term is a Must clause.
#[test]
fn test_parse_boolean_query_and() {
let stemmer = make_stemmer(StemLanguage::English);
let q = parse_boolean_query("foo bar", &stemmer);
let must_terms: Vec<_> = q
.clauses
.iter()
.filter_map(|c| match c {
Clause::Must(t) => Some(t.clone()),
_ => None,
})
.collect();
assert_eq!(must_terms.len(), 2, "expected two Must clauses, got: {q:?}");
assert!(!q.clauses.iter().any(|c| matches!(c, Clause::Should(_))));
}
// An OR keyword anywhere turns every positive term into a Should clause.
#[test]
fn test_parse_boolean_query_or() {
let stemmer = make_stemmer(StemLanguage::English);
let q = parse_boolean_query("foo OR bar", &stemmer);
let should_terms: Vec<_> = q
.clauses
.iter()
.filter_map(|c| match c {
Clause::Should(t) => Some(t.clone()),
_ => None,
})
.collect();
assert_eq!(should_terms.len(), 2, "expected two Should clauses");
assert!(!q.clauses.iter().any(|c| matches!(c, Clause::Must(_))));
}
#[test]
fn test_parse_boolean_query_not() {
let stemmer = make_stemmer(StemLanguage::English);
let q = parse_boolean_query("-foo", &stemmer);
assert_eq!(q.clauses.len(), 1);
assert!(matches!(&q.clauses[0], Clause::MustNot(_)));
assert!(!q.has_positive_clauses());
}
// A quoted phrase stays together as one multi-token clause.
#[test]
fn test_parse_boolean_query_phrase() {
let stemmer = make_stemmer(StemLanguage::English);
let q = parse_boolean_query("\"foo bar\"", &stemmer);
assert_eq!(q.clauses.len(), 1);
match &q.clauses[0] {
Clause::Must(tokens) => {
assert_eq!(tokens.len(), 2, "phrase should produce two tokens");
}
other => panic!("expected Must, got {other:?}"),
}
}
#[test]
fn test_parse_boolean_query_mixed() {
let stemmer = make_stemmer(StemLanguage::English);
let q = parse_boolean_query("foo OR bar -baz", &stemmer);
let should_count = q
.clauses
.iter()
.filter(|c| matches!(c, Clause::Should(_)))
.count();
let must_not_count = q
.clauses
.iter()
.filter(|c| matches!(c, Clause::MustNot(_)))
.count();
assert_eq!(should_count, 2, "foo and bar should be Should");
assert_eq!(must_not_count, 1, "baz should be MustNot");
assert!(!q.clauses.iter().any(|c| matches!(c, Clause::Must(_))));
}
#[test]
fn test_parse_boolean_query_case_insensitive_or() {
let stemmer = make_stemmer(StemLanguage::English);
let q_lower = parse_boolean_query("foo or bar", &stemmer);
let q_upper = parse_boolean_query("foo OR bar", &stemmer);
assert_eq!(q_lower.clauses.len(), q_upper.clauses.len());
assert!(
q_lower
.clauses
.iter()
.all(|c| matches!(c, Clause::Should(_)))
);
}
// An explicit AND keyword parses the same as plain juxtaposition.
#[test]
fn test_parse_boolean_query_and_keyword_ignored() {
let stemmer = make_stemmer(StemLanguage::English);
let q_explicit = parse_boolean_query("foo AND bar", &stemmer);
let q_implicit = parse_boolean_query("foo bar", &stemmer);
assert_eq!(q_explicit.clauses.len(), q_implicit.clauses.len());
assert!(
q_explicit
.clauses
.iter()
.all(|c| matches!(c, Clause::Must(_)))
);
}
#[test]
fn test_parse_boolean_query_empty() {
let stemmer = make_stemmer(StemLanguage::English);
let q = parse_boolean_query("", &stemmer);
assert!(q.is_empty());
assert!(!q.has_positive_clauses());
}
#[test]
fn test_parse_boolean_query_whitespace_only() {
let stemmer = make_stemmer(StemLanguage::English);
let q = parse_boolean_query(" ", &stemmer);
assert!(q.is_empty());
}
// ---- index building and scoring ----
// Builds an English document fixture.
fn doc(rel_path: &str, title: &str, body: &str) -> DocumentInput {
DocumentInput {
rel_path: rel_path.to_owned(),
title: title.to_owned(),
body: body.to_owned(),
language: StemLanguage::English,
}
}
#[test]
fn test_bm25_corpus_basic_search() {
let docs = vec![
doc(
"rust.md",
"Rust programming",
"Rust is a systems programming language.",
),
doc(
"python.md",
"Python programming",
"Python is a scripting language.",
),
doc(
"cooking.md",
"Cooking recipes",
"How to bake a delicious cake.",
),
];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("rust programming", &stemmer);
assert!(!results.is_empty(), "expected at least one result");
assert_eq!(results[0].rel_path, "rust.md");
}
// The query "run" must find "running" because both stem to the same token.
#[test]
fn test_bm25_corpus_stemming_matches() {
let docs = vec![
doc(
"running.md",
"Running guide",
"I enjoy running every morning.",
),
doc(
"cooking.md",
"Cooking guide",
"I enjoy cooking every evening.",
),
];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("run", &stemmer);
assert!(!results.is_empty(), "expected matches via stemming");
assert_eq!(results[0].rel_path, "running.md");
}
// A document that repeats the term more often should rank first.
#[test]
fn test_bm25_corpus_relevance_ranking() {
let docs = vec![
doc(
"many.md",
"Rust tips",
"Rust Rust Rust Rust Rust is great for systems programming.",
),
doc(
"few.md",
"Languages",
"Rust is one option among many languages.",
),
];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("rust", &stemmer);
assert!(results.len() >= 2);
assert_eq!(results[0].rel_path, "many.md");
}
#[test]
fn test_bm25_corpus_empty_query() {
let docs = vec![doc("a.md", "Title", "Some body text.")];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("", &stemmer);
assert!(results.is_empty());
}
#[test]
fn test_bm25_corpus_no_matches() {
let docs = vec![doc("a.md", "Title", "Some body text.")];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("xyzzy42quux", &stemmer);
assert!(results.is_empty());
}
#[test]
fn test_bm25_corpus_single_doc() {
let docs = vec![doc(
"single.md",
"Only document",
"This is the only document.",
)];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("document", &stemmer);
assert_eq!(results.len(), 1);
assert_eq!(results[0].rel_path, "single.md");
}
#[test]
fn test_bm25_corpus_score_returns_all() {
let docs = vec![
doc("a.md", "Alpha", "The quick brown fox."),
doc("b.md", "Beta", "The lazy dog slept."),
doc("c.md", "Gamma", "No matching content here."),
];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let all = index.score("quick", &stemmer);
assert_eq!(all.len(), 1);
assert_eq!(all[0].rel_path, "a.md");
}
// Implicit AND: only documents containing every term are returned.
#[test]
fn test_bm25_and_scoring() {
let docs = vec![
doc("both.md", "Alpha beta", "This doc has alpha and beta."),
doc("alpha_only.md", "Alpha", "This doc has only alpha."),
doc("beta_only.md", "Beta", "This doc has only beta."),
];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("alpha beta", &stemmer);
let paths: Vec<&str> = results.iter().map(|r| r.rel_path.as_str()).collect();
assert!(
paths.contains(&"both.md"),
"both.md should match: {paths:?}"
);
assert!(
!paths.contains(&"alpha_only.md"),
"alpha_only.md should be excluded: {paths:?}"
);
assert!(
!paths.contains(&"beta_only.md"),
"beta_only.md should be excluded: {paths:?}"
);
}
// OR: documents matching either alternative are returned.
#[test]
fn test_bm25_or_scoring() {
let docs = vec![
doc("rust.md", "Rust", "Rust systems programming."),
doc("python.md", "Python", "Python scripting language."),
doc("cooking.md", "Cooking", "Recipes and ingredients."),
];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("rust OR python", &stemmer);
let paths: Vec<&str> = results.iter().map(|r| r.rel_path.as_str()).collect();
assert!(
paths.contains(&"rust.md"),
"rust.md should match: {paths:?}"
);
assert!(
paths.contains(&"python.md"),
"python.md should match: {paths:?}"
);
assert!(
!paths.contains(&"cooking.md"),
"cooking.md should not match: {paths:?}"
);
}
// The title is tokenized before the body, so "Fast Rust" in the title
// supplies the adjacent pair in consecutive.md.
#[test]
fn test_bm25_phrase_matching() {
let docs = vec![
doc(
"consecutive.md",
"Fast Rust",
"Rust is a fast systems language.",
),
doc(
"non_consecutive.md",
"Systems Language",
"A fast language but not Rust specific.",
),
];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("\"fast rust\"", &stemmer);
let paths: Vec<&str> = results.iter().map(|r| r.rel_path.as_str()).collect();
assert!(
paths.contains(&"consecutive.md"),
"consecutive.md should match phrase: {paths:?}"
);
assert!(
!paths.contains(&"non_consecutive.md"),
"non_consecutive.md should not match phrase: {paths:?}"
);
}
#[test]
fn test_bm25_phrase_in_or_context() {
let docs = vec![
doc(
"consecutive.md",
"Fast Rust",
"Rust is a fast systems language.",
),
doc(
"non_consecutive.md",
"Systems Language",
"A fast language but not Rust specific.",
),
doc("cooking.md", "Cooking", "Recipes for healthy meals."),
];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("\"fast rust\" OR cooking", &stemmer);
let paths: Vec<&str> = results.iter().map(|r| r.rel_path.as_str()).collect();
assert!(
paths.contains(&"consecutive.md"),
"consecutive.md should match phrase in OR: {paths:?}"
);
assert!(
paths.contains(&"cooking.md"),
"cooking.md should match 'cooking' via OR: {paths:?}"
);
assert!(
!paths.contains(&"non_consecutive.md"),
"non_consecutive.md should NOT match phrase in OR (words not adjacent): {paths:?}"
);
}
// NOTE(review): this test and `test_bm25_negation_excludes_matching_docs`
// below have identical bodies; one of them could be removed or changed to
// cover a different case (for example a negated quoted phrase).
#[test]
fn test_bm25_negation_via_postings() {
let docs = vec![
doc("rust.md", "Rust", "Rust is a systems programming language."),
doc(
"python.md",
"Python",
"Python is a scripting programming language.",
),
doc("go.md", "Go", "Go is a compiled programming language."),
];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("programming -python", &stemmer);
let paths: Vec<&str> = results.iter().map(|r| r.rel_path.as_str()).collect();
assert!(
!paths.contains(&"python.md"),
"python.md should be excluded: {paths:?}"
);
assert!(
paths.contains(&"rust.md"),
"rust.md should remain: {paths:?}"
);
assert!(paths.contains(&"go.md"), "go.md should remain: {paths:?}");
}
#[test]
fn test_bm25_negation_excludes_matching_docs() {
let docs = vec![
doc("rust.md", "Rust", "Rust is a systems programming language."),
doc(
"python.md",
"Python",
"Python is a scripting programming language.",
),
doc("go.md", "Go", "Go is a compiled programming language."),
];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("programming -python", &stemmer);
let paths: Vec<&str> = results.iter().map(|r| r.rel_path.as_str()).collect();
assert!(
!paths.contains(&"python.md"),
"python.md should be excluded: {paths:?}"
);
assert!(
paths.contains(&"rust.md"),
"rust.md should remain: {paths:?}"
);
assert!(paths.contains(&"go.md"), "go.md should remain: {paths:?}");
}
// A query with only exclusions has nothing to rank, so it returns nothing.
#[test]
fn test_bm25_negation_only_returns_empty() {
let docs = vec![doc("a.md", "Title", "Some body text.")];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("-text", &stemmer);
assert!(
results.is_empty(),
"negation-only query should return empty"
);
}
// The negated term is stemmed too, so "-running" excludes a.md.
#[test]
fn test_bm25_negation_with_stemming() {
let docs = vec![
doc("a.md", "Running", "I love running every day."),
doc("b.md", "Swimming", "I love swimming every day."),
];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("love -running", &stemmer);
assert_eq!(results.len(), 1);
assert_eq!(results[0].rel_path, "b.md");
}
#[test]
fn test_bm25_empty_corpus() {
let index = Bm25InvertedIndex::build(vec![]);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("anything", &stemmer);
assert!(results.is_empty());
}
#[test]
fn test_bm25_build_from_tokens() {
let docs = vec![
PreTokenizedInput {
rel_path: "a.md".to_owned(),
tokens: vec!["rust".to_owned(), "program".to_owned()],
},
PreTokenizedInput {
rel_path: "b.md".to_owned(),
tokens: vec!["python".to_owned(), "program".to_owned()],
},
];
let index = Bm25InvertedIndex::build_from_tokens(docs);
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("rust", &stemmer);
assert_eq!(results.len(), 1);
assert_eq!(results[0].rel_path, "a.md");
}
// Serializing and deserializing must not change search results.
#[test]
fn test_bm25_serde_round_trip() {
let docs = vec![
doc("rust.md", "Rust", "Rust is a systems programming language."),
doc("python.md", "Python", "Python is a scripting language."),
];
let index = Bm25InvertedIndex::build(docs);
let stemmer = make_stemmer(StemLanguage::English);
let bytes = rmp_serde::to_vec_named(&index).expect("serialize");
let restored: Bm25InvertedIndex = rmp_serde::from_slice(&bytes).expect("deserialize");
let before = index.score("rust", &stemmer);
let after = restored.score("rust", &stemmer);
assert_eq!(before.len(), after.len());
assert_eq!(before[0].rel_path, after[0].rel_path);
assert!((before[0].score - after[0].score).abs() < f64::EPSILON);
}
// Entries without stored tokens are skipped when building from entries.
#[test]
fn test_bm25_build_from_entries() {
use crate::index::IndexEntry;
use indexmap::IndexMap;
let entries = vec![
IndexEntry {
rel_path: "a.md".to_owned(),
modified: String::new(),
properties: IndexMap::new(),
tags: Vec::new(),
sections: Vec::new(),
tasks: Vec::new(),
links: Vec::new(),
bm25_tokens: Some(vec!["rust".to_owned(), "program".to_owned()]),
bm25_language: Some("english".to_owned()),
},
IndexEntry {
rel_path: "b.md".to_owned(),
modified: String::new(),
properties: IndexMap::new(),
tags: Vec::new(),
sections: Vec::new(),
tasks: Vec::new(),
links: Vec::new(),
bm25_tokens: None, bm25_language: None,
},
];
let index = Bm25InvertedIndex::build_from_entries(&entries);
assert!(index.is_some(), "should build from entries with tokens");
let index = index.unwrap();
let stemmer = make_stemmer(StemLanguage::English);
let results = index.score("rust", &stemmer);
assert_eq!(results.len(), 1);
assert_eq!(results[0].rel_path, "a.md");
}
#[test]
fn test_bm25_build_from_entries_none_when_no_tokens() {
use crate::index::IndexEntry;
use indexmap::IndexMap;
let entries = vec![IndexEntry {
rel_path: "a.md".to_owned(),
modified: String::new(),
properties: IndexMap::new(),
tags: Vec::new(),
sections: Vec::new(),
tasks: Vec::new(),
links: Vec::new(),
bm25_tokens: None,
bm25_language: None,
}];
assert!(
Bm25InvertedIndex::build_from_entries(&entries).is_none(),
"no tokens → should return None"
);
}
// ---- is_low_discriminative ----
// Builds a match fixture with the given path and score.
fn make_match(rel_path: &str, score: f64) -> Bm25Match {
Bm25Match {
rel_path: rel_path.to_owned(),
score,
}
}
#[test]
fn test_low_discriminative_empty_matches_returns_false() {
assert!(!is_low_discriminative(&[], 10));
}
#[test]
fn test_low_discriminative_zero_total_docs_returns_false() {
let matches = vec![make_match("a.md", 0.5)];
assert!(!is_low_discriminative(&matches, 0));
}
// 9 of 10 documents matched (ratio 0.9) and every score is below 1.0.
#[test]
fn test_low_discriminative_high_ratio_low_scores_returns_true() {
let matches: Vec<Bm25Match> = (0..9)
.map(|i| make_match(&format!("{i}.md"), 0.5))
.collect();
assert!(is_low_discriminative(&matches, 10));
}
#[test]
fn test_low_discriminative_low_ratio_returns_false() {
let matches: Vec<Bm25Match> = (0..3)
.map(|i| make_match(&format!("{i}.md"), 0.5))
.collect();
assert!(!is_low_discriminative(&matches, 10));
}
// One high-scoring match is enough to make the query discriminative.
#[test]
fn test_low_discriminative_high_ratio_high_scores_returns_false() {
let mut matches: Vec<Bm25Match> = (0..8)
.map(|i| make_match(&format!("{i}.md"), 0.5))
.collect();
matches.push(make_match("high.md", 2.5));
assert!(!is_low_discriminative(&matches, 10));
}
// ---- query_is_operator_only ----
#[test]
fn test_query_is_operator_only_and() {
assert!(
query_is_operator_only("and"),
"bare 'and' should be operator-only"
);
assert!(
query_is_operator_only("AND"),
"case-insensitive 'AND' should be operator-only"
);
}
#[test]
fn test_query_is_operator_only_or() {
assert!(
query_is_operator_only("or"),
"bare 'or' should be operator-only"
);
assert!(
query_is_operator_only("OR"),
"case-insensitive 'OR' should be operator-only"
);
}
#[test]
fn test_query_is_operator_only_multiple_operators() {
assert!(
query_is_operator_only("and or"),
"multiple operator keywords should be operator-only"
);
assert!(
query_is_operator_only("or and"),
"reversed order should still be operator-only"
);
}
#[test]
fn test_query_is_operator_only_mixed_returns_false() {
assert!(
!query_is_operator_only("rust and"),
"mixed query should not be operator-only"
);
assert!(
!query_is_operator_only("foo OR bar"),
"query with real terms should not be operator-only"
);
}
#[test]
fn test_query_is_operator_only_plain_word_returns_false() {
assert!(!query_is_operator_only("rust"), "'rust' is not an operator");
assert!(
!query_is_operator_only("not"),
"'not' is not handled as a boolean operator"
);
}
#[test]
fn test_query_is_operator_only_empty_returns_false() {
assert!(
!query_is_operator_only(""),
"empty query should return false"
);
assert!(
!query_is_operator_only(" "),
"whitespace-only query should return false"
);
}
}