use rust_stemmers::{Algorithm, Stemmer};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Stemming language for the text analyzer.
///
/// `None` (the default) disables stemming; every other variant selects the
/// matching Snowball algorithm from the `rust_stemmers` crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum StemLanguage {
    /// No stemming: tokens are indexed exactly as tokenized.
    #[default]
    None,
    Arabic,
    Danish,
    Dutch,
    English,
    Finnish,
    French,
    German,
    Greek,
    Hungarian,
    Italian,
    Norwegian,
    Portuguese,
    Romanian,
    Russian,
    Spanish,
    Swedish,
    Tamil,
    Turkish,
}
impl StemLanguage {
fn to_algorithm(self) -> Option<Algorithm> {
match self {
StemLanguage::None => None,
StemLanguage::Arabic => Some(Algorithm::Arabic),
StemLanguage::Danish => Some(Algorithm::Danish),
StemLanguage::Dutch => Some(Algorithm::Dutch),
StemLanguage::English => Some(Algorithm::English),
StemLanguage::Finnish => Some(Algorithm::Finnish),
StemLanguage::French => Some(Algorithm::French),
StemLanguage::German => Some(Algorithm::German),
StemLanguage::Greek => Some(Algorithm::Greek),
StemLanguage::Hungarian => Some(Algorithm::Hungarian),
StemLanguage::Italian => Some(Algorithm::Italian),
StemLanguage::Norwegian => Some(Algorithm::Norwegian),
StemLanguage::Portuguese => Some(Algorithm::Portuguese),
StemLanguage::Romanian => Some(Algorithm::Romanian),
StemLanguage::Russian => Some(Algorithm::Russian),
StemLanguage::Spanish => Some(Algorithm::Spanish),
StemLanguage::Swedish => Some(Algorithm::Swedish),
StemLanguage::Tamil => Some(Algorithm::Tamil),
StemLanguage::Turkish => Some(Algorithm::Turkish),
}
}
pub fn parse_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"none" | "" => Some(StemLanguage::None),
"arabic" | "ar" => Some(StemLanguage::Arabic),
"danish" | "da" => Some(StemLanguage::Danish),
"dutch" | "nl" => Some(StemLanguage::Dutch),
"english" | "en" => Some(StemLanguage::English),
"finnish" | "fi" => Some(StemLanguage::Finnish),
"french" | "fr" => Some(StemLanguage::French),
"german" | "de" => Some(StemLanguage::German),
"greek" | "el" => Some(StemLanguage::Greek),
"hungarian" | "hu" => Some(StemLanguage::Hungarian),
"italian" | "it" => Some(StemLanguage::Italian),
"norwegian" | "no" => Some(StemLanguage::Norwegian),
"portuguese" | "pt" => Some(StemLanguage::Portuguese),
"romanian" | "ro" => Some(StemLanguage::Romanian),
"russian" | "ru" => Some(StemLanguage::Russian),
"spanish" | "es" => Some(StemLanguage::Spanish),
"swedish" | "sv" => Some(StemLanguage::Swedish),
"tamil" | "ta" => Some(StemLanguage::Tamil),
"turkish" | "tr" => Some(StemLanguage::Turkish),
_ => None,
}
}
pub fn supported_languages() -> &'static [&'static str] {
&[
"arabic",
"danish",
"dutch",
"english",
"finnish",
"french",
"german",
"greek",
"hungarian",
"italian",
"norwegian",
"portuguese",
"romanian",
"russian",
"spanish",
"swedish",
"tamil",
"turkish",
]
}
}
/// Tunable parameters for tokenization and BM25 scoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FullTextConfig {
    // BM25 term-frequency saturation parameter (typical range 1.2–2.0).
    pub k1: f32,
    // BM25 document-length normalization strength (0 = none, 1 = full).
    pub b: f32,
    // Tokens shorter than this are dropped during analysis.
    pub min_token_length: usize,
    // Tokens longer than this are dropped during analysis.
    pub max_token_length: usize,
    // Lowercase text before tokenization when true.
    pub lowercase: bool,
    // Tokens in this set are dropped (checked before stemming).
    pub stop_words: HashSet<String>,
    // Stemming language; `StemLanguage::None` disables stemming.
    pub stem_language: StemLanguage,
}
impl Default for FullTextConfig {
fn default() -> Self {
Self {
k1: 1.2,
b: 0.75,
min_token_length: 2,
max_token_length: 50,
lowercase: true,
stop_words: default_stop_words(),
stem_language: StemLanguage::None,
}
}
}
impl FullTextConfig {
    /// Default configuration with English stemming enabled.
    ///
    /// Convenience shorthand for `with_language(StemLanguage::English)`.
    pub fn with_english_stemming() -> Self {
        // Delegate instead of duplicating the construction logic.
        Self::with_language(StemLanguage::English)
    }

    /// Default configuration with the given stemming language.
    pub fn with_language(language: StemLanguage) -> Self {
        Self {
            stem_language: language,
            ..Default::default()
        }
    }
}
/// Build the default English stop-word set used by `FullTextConfig::default`.
fn default_stop_words() -> HashSet<String> {
    const WORDS: &[&str] = &[
        "a", "an", "and", "are", "as", "at", "be", "by", "for", "from", "has", "he", "in", "is",
        "it", "its", "of", "on", "or", "that", "the", "to", "was", "were", "will", "with", "this",
        "but", "they", "have", "had", "what", "when", "where", "who", "which", "why", "how",
    ];
    WORDS.iter().copied().map(String::from).collect()
}
/// Turns raw text into normalized index terms according to a `FullTextConfig`.
#[derive(Debug, Clone, Serialize)]
pub struct TextAnalyzer {
    config: FullTextConfig,
    // Cached result of `config.stem_language.to_algorithm()`.
    // Skipped on serialize (rebuilt in the manual Deserialize impl below).
    #[serde(skip)]
    stem_algorithm: Option<Algorithm>,
}
// Manual Deserialize: `stem_algorithm` is `#[serde(skip)]` (the Algorithm type
// is not serialized), so it must be rebuilt from the config after
// deserialization — `TextAnalyzer::new` does exactly that.
impl<'de> Deserialize<'de> for TextAnalyzer {
    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        // Mirrors TextAnalyzer's serialized shape: only `config` is persisted.
        #[derive(Deserialize)]
        struct Helper {
            config: FullTextConfig,
        }
        let h = Helper::deserialize(d)?;
        Ok(TextAnalyzer::new(h.config))
    }
}
impl TextAnalyzer {
    /// Build an analyzer from `config`, caching the stem-algorithm lookup.
    pub fn new(config: FullTextConfig) -> Self {
        let stem_algorithm = config.stem_language.to_algorithm();
        Self {
            config,
            stem_algorithm,
        }
    }

    /// Tokenize `text` into normalized terms.
    ///
    /// Pipeline: optional lowercasing, splitting on non-alphanumeric
    /// characters, length and stop-word filtering (stop words are matched
    /// against the unstemmed token), then optional stemming.
    pub fn analyze(&self, text: &str) -> Vec<String> {
        let text = if self.config.lowercase {
            text.to_lowercase()
        } else {
            text.to_string()
        };
        // Hoist stemmer creation out of the per-token loop; stemming in the
        // same pass avoids collecting an intermediate Vec.
        let stemmer = self.stem_algorithm.map(Stemmer::create);
        text.split(|c: char| !c.is_alphanumeric())
            .filter(|token| {
                // Measure length in characters, not bytes, so multi-byte
                // scripts (Greek, Russian, Arabic, Tamil, ...) are filtered
                // consistently with ASCII text.
                let len = token.chars().count();
                // Explicitly drop empty splits even if min_token_length is 0.
                !token.is_empty()
                    && len >= self.config.min_token_length
                    && len <= self.config.max_token_length
                    && !self.config.stop_words.contains(*token)
            })
            .map(|token| match &stemmer {
                Some(stemmer) => stemmer.stem(token).to_string(),
                None => token.to_string(),
            })
            .collect()
    }

    /// Analyze `text` and count occurrences of each resulting term.
    pub fn token_frequencies(&self, text: &str) -> HashMap<String, u32> {
        let mut freqs = HashMap::new();
        for token in self.analyze(text) {
            *freqs.entry(token).or_insert(0) += 1;
        }
        freqs
    }

    /// The configured stemming language.
    pub fn stem_language(&self) -> StemLanguage {
        self.config.stem_language
    }

    /// Whether this analyzer stems tokens.
    pub fn stemming_enabled(&self) -> bool {
        self.stem_algorithm.is_some()
    }
}
impl Default for TextAnalyzer {
fn default() -> Self {
Self::new(FullTextConfig::default())
}
}
/// One document's entry in a term's postings list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Posting {
    // Identifier of the document containing the term.
    pub doc_id: String,
    // Number of occurrences of the term in the document.
    pub term_freq: u32,
    // Token positions (0-based, post-analysis) of each occurrence.
    pub positions: Vec<u32>,
}
/// BM25-scored inverted index over analyzed text documents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvertedIndex {
    // term -> postings list (one Posting per document containing the term).
    index: HashMap<String, Vec<Posting>>,
    // doc_id -> token count (document length after analysis).
    doc_lengths: HashMap<String, u32>,
    // doc_id -> optional metadata used by filtered search.
    doc_metadata: HashMap<String, serde_json::Value>,
    // Number of indexed documents (kept in sync with doc_lengths).
    doc_count: u32,
    // Mean document length in tokens; 0.0 when the index is empty.
    avg_doc_length: f32,
    analyzer: TextAnalyzer,
    // NOTE(review): duplicates the config already held inside `analyzer`;
    // kept for serialization compatibility — confirm before consolidating.
    config: FullTextConfig,
}
impl InvertedIndex {
    /// Create an empty index using `config` for both analysis and BM25 scoring.
    pub fn new(config: FullTextConfig) -> Self {
        let analyzer = TextAnalyzer::new(config.clone());
        Self {
            index: HashMap::new(),
            doc_lengths: HashMap::new(),
            doc_metadata: HashMap::new(),
            doc_count: 0,
            avg_doc_length: 0.0,
            analyzer,
            config,
        }
    }

    /// Index (or re-index) a document without metadata.
    pub fn add_document(&mut self, doc_id: &str, text: &str) {
        self.add_document_with_metadata(doc_id, text, None);
    }

    /// Index a document, optionally attaching metadata for filtered search.
    ///
    /// An existing document with the same `doc_id` is fully replaced,
    /// including its metadata.
    pub fn add_document_with_metadata(
        &mut self,
        doc_id: &str,
        text: &str,
        metadata: Option<serde_json::Value>,
    ) {
        // Replace semantics: drop any previous postings/metadata first.
        self.remove_document(doc_id);
        if let Some(meta) = metadata {
            self.doc_metadata.insert(doc_id.to_string(), meta);
        }
        let tokens = self.analyzer.analyze(text);
        self.doc_lengths
            .insert(doc_id.to_string(), tokens.len() as u32);
        self.doc_count += 1;
        self.recompute_avg_doc_length();
        // Aggregate per-term frequency and token positions for this document.
        let mut term_data: HashMap<String, (u32, Vec<u32>)> = HashMap::new();
        for (pos, token) in tokens.into_iter().enumerate() {
            let entry = term_data.entry(token).or_insert((0, Vec::new()));
            entry.0 += 1;
            entry.1.push(pos as u32);
        }
        for (token, (freq, positions)) in term_data {
            let posting = Posting {
                doc_id: doc_id.to_string(),
                term_freq: freq,
                positions,
            };
            self.index.entry(token).or_default().push(posting);
        }
    }

    /// Remove a document; returns `false` if it was not indexed.
    pub fn remove_document(&mut self, doc_id: &str) -> bool {
        if self.doc_lengths.remove(doc_id).is_none() {
            return false;
        }
        self.doc_metadata.remove(doc_id);
        self.doc_count = self.doc_count.saturating_sub(1);
        self.recompute_avg_doc_length();
        // O(total postings): fine for modest indexes; revisit if removals
        // become a hot path.
        for postings in self.index.values_mut() {
            postings.retain(|p| p.doc_id != doc_id);
        }
        self.index.retain(|_, v| !v.is_empty());
        true
    }

    /// Keep `avg_doc_length` consistent after any add or remove.
    fn recompute_avg_doc_length(&mut self) {
        if self.doc_count > 0 {
            let total_length: u32 = self.doc_lengths.values().sum();
            self.avg_doc_length = total_length as f32 / self.doc_count as f32;
        } else {
            self.avg_doc_length = 0.0;
        }
    }

    /// BM25-ranked search. Returns at most `top_k` results, best first.
    /// Query terms are analyzed with the same pipeline as documents.
    pub fn search(&self, query: &str, top_k: usize) -> Vec<FullTextResult> {
        let query_tokens = self.analyzer.analyze(query);
        if query_tokens.is_empty() {
            return Vec::new();
        }
        // Sum per-term BM25 contributions into a per-document score.
        let mut scores: HashMap<String, f32> = HashMap::new();
        for token in &query_tokens {
            if let Some(postings) = self.index.get(token) {
                let idf = self.calculate_idf(postings.len());
                for posting in postings {
                    let doc_length = self.doc_lengths.get(&posting.doc_id).copied().unwrap_or(0);
                    let tf_score = self.calculate_tf(posting.term_freq, doc_length);
                    let score = idf * tf_score;
                    *scores.entry(posting.doc_id.clone()).or_insert(0.0) += score;
                }
            }
        }
        let mut results: Vec<_> = scores
            .into_iter()
            .map(|(doc_id, score)| FullTextResult { doc_id, score })
            .collect();
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(top_k);
        results
    }

    /// BM25+-style inverse document frequency. The `+ 1.0` inside the log
    /// keeps the IDF strictly positive even for very common terms.
    fn calculate_idf(&self, doc_freq: usize) -> f32 {
        let n = self.doc_count as f32;
        let df = doc_freq as f32;
        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
    }

    /// BM25 term-frequency component with document-length normalization.
    fn calculate_tf(&self, term_freq: u32, doc_length: u32) -> f32 {
        let tf = term_freq as f32;
        let dl = doc_length as f32;
        let avgdl = self.avg_doc_length;
        let k1 = self.config.k1;
        let b = self.config.b;
        // Guard against 0/0 = NaN when every indexed document tokenized to
        // zero terms (avg length 0); fall back to no length normalization.
        let length_norm = if avgdl > 0.0 {
            1.0 - b + b * (dl / avgdl)
        } else {
            1.0
        };
        (tf * (k1 + 1.0)) / (tf + k1 * length_norm)
    }

    /// Snapshot of index-level statistics.
    pub fn stats(&self) -> FullTextStats {
        FullTextStats {
            document_count: self.doc_count as usize,
            unique_terms: self.index.len(),
            avg_document_length: self.avg_doc_length,
            total_postings: self.index.values().map(|v| v.len()).sum(),
        }
    }

    /// Whether `doc_id` is currently indexed.
    pub fn contains(&self, doc_id: &str) -> bool {
        self.doc_lengths.contains_key(doc_id)
    }

    /// Metadata attached to `doc_id`, if any.
    pub fn get_metadata(&self, doc_id: &str) -> Option<&serde_json::Value> {
        self.doc_metadata.get(doc_id)
    }

    /// Search, then keep only documents whose metadata passes `filter`.
    /// Documents with no metadata are evaluated against `None`.
    pub fn search_with_filter(
        &self,
        query: &str,
        top_k: usize,
        filter: Option<&common::FilterExpression>,
    ) -> Vec<FullTextResult> {
        if let Some(filter_expr) = filter {
            use crate::filter::evaluate_filter;
            // Score every candidate before filtering: pre-truncating (the old
            // `top_k * 2` heuristic) could silently drop matches that would
            // have survived the filter. `search` scores all candidates anyway,
            // so this costs nothing extra.
            self.search(query, usize::MAX)
                .into_iter()
                .filter(|r| evaluate_filter(filter_expr, self.doc_metadata.get(&r.doc_id)))
                .take(top_k)
                .collect()
        } else {
            self.search(query, top_k)
        }
    }

    /// Number of indexed documents.
    pub fn len(&self) -> usize {
        self.doc_count as usize
    }

    /// True when no documents are indexed.
    pub fn is_empty(&self) -> bool {
        self.doc_count == 0
    }

    /// Remove every document, postings list, and metadata entry.
    pub fn clear(&mut self) {
        self.index.clear();
        self.doc_lengths.clear();
        // Bug fix: metadata previously leaked across `clear()`, leaving stale
        // entries that filtered searches could still observe after re-adding.
        self.doc_metadata.clear();
        self.doc_count = 0;
        self.avg_doc_length = 0.0;
    }
}
impl Default for InvertedIndex {
fn default() -> Self {
Self::new(FullTextConfig::default())
}
}
/// A single search hit: document id plus its BM25 score.
#[derive(Debug, Clone)]
pub struct FullTextResult {
    pub doc_id: String,
    // Accumulated BM25 score across all matching query terms (higher = better).
    pub score: f32,
}
/// Aggregate statistics reported by `InvertedIndex::stats`.
#[derive(Debug, Clone)]
pub struct FullTextStats {
    // Number of indexed documents.
    pub document_count: usize,
    // Number of distinct terms with at least one posting.
    pub unique_terms: usize,
    // Mean document length in tokens (0.0 for an empty index).
    pub avg_document_length: f32,
    // Total postings across all terms.
    pub total_postings: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_analyzer_basic() {
        let analyzer = TextAnalyzer::default();
        let tokens = analyzer.analyze("Hello World! This is a test.");
        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
        assert!(tokens.contains(&"test".to_string()));
        // "a" is both a stop word and below the minimum token length.
        assert!(!tokens.contains(&"a".to_string()));
    }

    #[test]
    fn test_text_analyzer_case_insensitive() {
        let analyzer = TextAnalyzer::default();
        let tokens = analyzer.analyze("HELLO hello HeLLo");
        assert_eq!(tokens.iter().filter(|t| *t == "hello").count(), 3);
    }

    #[test]
    fn test_text_analyzer_token_length() {
        let config = FullTextConfig {
            min_token_length: 3,
            max_token_length: 5,
            ..Default::default()
        };
        let analyzer = TextAnalyzer::new(config);
        let tokens = analyzer.analyze("a ab abc abcd abcde abcdef");
        assert!(!tokens.contains(&"a".to_string()));
        assert!(!tokens.contains(&"ab".to_string()));
        assert!(tokens.contains(&"abc".to_string()));
        assert!(tokens.contains(&"abcd".to_string()));
        assert!(tokens.contains(&"abcde".to_string()));
        assert!(!tokens.contains(&"abcdef".to_string()));
    }

    #[test]
    fn test_token_frequencies() {
        let analyzer = TextAnalyzer::default();
        let freqs = analyzer.token_frequencies("hello hello world hello");
        assert_eq!(freqs.get("hello"), Some(&3));
        assert_eq!(freqs.get("world"), Some(&1));
    }

    #[test]
    fn test_inverted_index_add_and_search() {
        let mut index = InvertedIndex::default();
        index.add_document("doc1", "The quick brown fox jumps over the lazy dog");
        index.add_document("doc2", "A quick brown dog runs in the park");
        index.add_document("doc3", "The lazy cat sleeps all day");
        let results = index.search("quick brown", 10);
        assert!(!results.is_empty());
        assert!(results.iter().any(|r| r.doc_id == "doc1"));
        assert!(results.iter().any(|r| r.doc_id == "doc2"));
        assert!(!results.iter().any(|r| r.doc_id == "doc3"));
    }

    #[test]
    fn test_inverted_index_ranking() {
        let mut index = InvertedIndex::default();
        index.add_document("doc1", "rust is awesome, rust programming");
        index.add_document("doc2", "rust programming language");
        index.add_document("doc3", "python programming language");
        let results = index.search("rust", 10);
        assert_eq!(results.len(), 2);
        // doc1 mentions "rust" twice, so it must out-rank doc2.
        assert_eq!(results[0].doc_id, "doc1");
        assert_eq!(results[1].doc_id, "doc2");
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn test_inverted_index_remove() {
        let mut index = InvertedIndex::default();
        index.add_document("doc1", "hello world");
        index.add_document("doc2", "hello universe");
        assert_eq!(index.len(), 2);
        let removed = index.remove_document("doc1");
        assert!(removed);
        assert_eq!(index.len(), 1);
        let results = index.search("hello", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, "doc2");
    }

    #[test]
    fn test_inverted_index_update() {
        let mut index = InvertedIndex::default();
        index.add_document("doc1", "original content about cats");
        let results1 = index.search("cats", 10);
        assert_eq!(results1.len(), 1);
        // Re-adding under the same id fully replaces the old postings.
        index.add_document("doc1", "updated content about dogs");
        let results2 = index.search("cats", 10);
        assert_eq!(results2.len(), 0);
        let results3 = index.search("dogs", 10);
        assert_eq!(results3.len(), 1);
    }

    #[test]
    fn test_inverted_index_empty_query() {
        let mut index = InvertedIndex::default();
        index.add_document("doc1", "hello world");
        // A query of nothing but stop words analyzes to zero tokens.
        let results = index.search("the is a", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_inverted_index_stats() {
        let mut index = InvertedIndex::default();
        index.add_document("doc1", "hello world test");
        index.add_document("doc2", "hello universe example");
        let stats = index.stats();
        assert_eq!(stats.document_count, 2);
        assert!(stats.unique_terms > 0);
        assert!(stats.avg_document_length > 0.0);
        assert!(stats.total_postings > 0);
    }

    #[test]
    fn test_inverted_index_clear() {
        let mut index = InvertedIndex::default();
        index.add_document("doc1", "hello world");
        index.add_document("doc2", "hello universe");
        assert_eq!(index.len(), 2);
        index.clear();
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
        assert_eq!(index.stats().unique_terms, 0);
    }

    #[test]
    fn test_bm25_idf() {
        // Previously a vacuous placeholder; now verifies IDF ordering.
        let mut index = InvertedIndex::default();
        index.add_document("doc1", "rare shared");
        index.add_document("doc2", "shared filler");
        index.add_document("doc3", "shared filler words");
        // A term in 1 of 3 documents must score higher than a term in all 3,
        // and the BM25+ `+ 1.0` keeps every IDF strictly positive.
        let rare_idf = index.calculate_idf(1);
        let common_idf = index.calculate_idf(3);
        assert!(rare_idf > common_idf);
        assert!(common_idf > 0.0);
    }

    #[test]
    fn test_bm25_length_normalization() {
        let mut index = InvertedIndex::default();
        index.add_document("short", "rust");
        index.add_document(
            "long",
            "rust programming language framework library ecosystem tools community",
        );
        let results = index.search("rust", 10);
        assert_eq!(results.len(), 2);
        // Equal term frequency: length normalization must rank the
        // shorter document first.
        assert_eq!(results[0].doc_id, "short");
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn test_contains() {
        let mut index = InvertedIndex::default();
        index.add_document("doc1", "hello world");
        assert!(index.contains("doc1"));
        assert!(!index.contains("doc2"));
    }

    #[test]
    fn test_custom_config() {
        let config = FullTextConfig {
            k1: 1.5,
            b: 0.5,
            min_token_length: 1,
            max_token_length: 100,
            lowercase: false,
            stop_words: HashSet::new(),
            stem_language: StemLanguage::None,
        };
        let mut index = InvertedIndex::new(config);
        index.add_document("doc1", "A B C");
        // lowercase=false preserves case, so the uppercase query matches.
        let results = index.search("A", 10);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_special_characters() {
        let mut index = InvertedIndex::default();
        // '@', '.', '-', '_' are all non-alphanumeric token separators.
        index.add_document("doc1", "hello@world.com test-case under_score");
        let results = index.search("hello", 10);
        assert_eq!(results.len(), 1);
        let results = index.search("world", 10);
        assert_eq!(results.len(), 1);
        let results = index.search("test", 10);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_numeric_tokens() {
        let mut index = InvertedIndex::default();
        index.add_document("doc1", "version 123 release 2024");
        let results = index.search("123", 10);
        assert_eq!(results.len(), 1);
        let results = index.search("2024", 10);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_phrase_search_basic() {
        let mut index = InvertedIndex::default();
        index.add_document("doc1", "quick brown fox");
        index.add_document("doc2", "brown quick fox");
        // Bag-of-words scoring: word order does not affect matching.
        let results = index.search("quick brown", 10);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_english_stemming() {
        let config = FullTextConfig::with_english_stemming();
        let analyzer = TextAnalyzer::new(config);
        let tokens = analyzer.analyze("The cats were running and jumping");
        assert!(tokens.contains(&"cat".to_string()));
        assert!(tokens.contains(&"run".to_string()));
        assert!(tokens.contains(&"jump".to_string()));
    }

    #[test]
    fn test_english_stemming_search() {
        let config = FullTextConfig::with_english_stemming();
        let mut index = InvertedIndex::new(config);
        index.add_document("doc1", "The programmer is programming applications");
        index.add_document("doc2", "Software development requires developers");
        index.add_document("doc3", "Cooking recipes for beginners");
        let results = index.search("program", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, "doc1");
        let results = index.search("develop", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, "doc2");
        let results = index.search("programming", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, "doc1");
    }

    #[test]
    fn test_german_stemming() {
        let config = FullTextConfig::with_language(StemLanguage::German);
        let analyzer = TextAnalyzer::new(config);
        let _tokens = analyzer.analyze("Die Entwickler entwickeln Software");
        assert!(analyzer.stemming_enabled());
        assert_eq!(analyzer.stem_language(), StemLanguage::German);
    }

    #[test]
    fn test_french_stemming() {
        let config = FullTextConfig::with_language(StemLanguage::French);
        let analyzer = TextAnalyzer::new(config);
        let tokens = analyzer.analyze("Les programmeurs programment des applications");
        assert!(analyzer.stemming_enabled());
        assert_eq!(analyzer.stem_language(), StemLanguage::French);
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_spanish_stemming() {
        let config = FullTextConfig::with_language(StemLanguage::Spanish);
        let analyzer = TextAnalyzer::new(config);
        let tokens = analyzer.analyze("Los desarrolladores desarrollan aplicaciones");
        assert!(analyzer.stemming_enabled());
        assert_eq!(analyzer.stem_language(), StemLanguage::Spanish);
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_no_stemming_default() {
        let analyzer = TextAnalyzer::default();
        assert!(!analyzer.stemming_enabled());
        assert_eq!(analyzer.stem_language(), StemLanguage::None);
        let tokens = analyzer.analyze("running jumped cats");
        assert!(tokens.contains(&"running".to_string()));
        assert!(tokens.contains(&"jumped".to_string()));
        assert!(tokens.contains(&"cats".to_string()));
    }

    #[test]
    fn test_stem_language_from_str() {
        assert_eq!(
            StemLanguage::parse_str("english"),
            Some(StemLanguage::English)
        );
        assert_eq!(StemLanguage::parse_str("en"), Some(StemLanguage::English));
        assert_eq!(
            StemLanguage::parse_str("ENGLISH"),
            Some(StemLanguage::English)
        );
        assert_eq!(
            StemLanguage::parse_str("french"),
            Some(StemLanguage::French)
        );
        assert_eq!(StemLanguage::parse_str("fr"), Some(StemLanguage::French));
        assert_eq!(
            StemLanguage::parse_str("german"),
            Some(StemLanguage::German)
        );
        assert_eq!(StemLanguage::parse_str("de"), Some(StemLanguage::German));
        assert_eq!(
            StemLanguage::parse_str("spanish"),
            Some(StemLanguage::Spanish)
        );
        assert_eq!(StemLanguage::parse_str("es"), Some(StemLanguage::Spanish));
        assert_eq!(StemLanguage::parse_str("none"), Some(StemLanguage::None));
        assert_eq!(StemLanguage::parse_str(""), Some(StemLanguage::None));
        assert_eq!(StemLanguage::parse_str("invalid"), None);
    }

    #[test]
    fn test_supported_languages() {
        let languages = StemLanguage::supported_languages();
        assert!(languages.contains(&"english"));
        assert!(languages.contains(&"french"));
        assert!(languages.contains(&"german"));
        assert!(languages.contains(&"spanish"));
        assert!(languages.contains(&"italian"));
        assert!(languages.contains(&"portuguese"));
        assert!(languages.contains(&"russian"));
        assert!(languages.contains(&"arabic"));
        assert_eq!(languages.len(), 18);
    }
}