use crate::error::Result;
use crate::types::{Document, Value};
use dashmap::DashMap;
use fst::automaton::Levenshtein;
use fst::{IntoStreamer, Streamer};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use unicode_segmentation::UnicodeSegmentation;
pub fn levenshtein_distance(s1: &str, s2: &str) -> usize {
let s1: Vec<char> = s1.chars().collect();
let s2: Vec<char> = s2.chars().collect();
if s1.is_empty() { return s2.len(); }
if s2.is_empty() { return s1.len(); }
let mut prev: Vec<usize> = (0..=s2.len()).collect();
let mut curr = vec![0usize; s2.len() + 1];
for i in 1..=s1.len() {
curr[0] = i;
for j in 1..=s2.len() {
let cost = if s1[i-1] == s2[j-1] { 0 } else { 1 };
curr[j] = (curr[j-1] + 1).min((prev[j] + 1).min(prev[j-1] + cost));
}
std::mem::swap(&mut prev, &mut curr);
}
prev[s2.len()]
}
pub struct FullTextIndex {
field: String,
index: Arc<DashMap<String, Vec<(String, f32)>>>,
doc_lengths: Arc<DashMap<String, usize>>,
document_count: Arc<AtomicUsize>,
total_term_count: Arc<AtomicU64>,
term_fst: Arc<std::sync::RwLock<Option<fst::Set<Vec<u8>>>>>,
stop_words: HashSet<String>,
enable_stop_words: bool,
k1: f32, b: f32, }
impl FullTextIndex {
pub fn new(_collection: &str, field: &str) -> Self {
let mut index = Self {
field: field.to_string(),
index: Arc::new(DashMap::new()),
doc_lengths: Arc::new(DashMap::new()),
document_count: Arc::new(AtomicUsize::new(0)),
total_term_count: Arc::new(AtomicU64::new(0)),
term_fst: Arc::new(std::sync::RwLock::new(None)),
stop_words: Self::default_stop_words(),
enable_stop_words: true,
k1: 1.2, b: 0.75, };
index.stop_words = Self::default_stop_words();
index
}
pub fn with_options(
collection: &str,
field: &str,
k1: f32,
b: f32,
enable_stop_words: bool,
) -> Self {
let mut index = Self::new(collection, field);
index.k1 = k1;
index.b = b;
index.enable_stop_words = enable_stop_words;
index
}
fn default_stop_words() -> HashSet<String> {
let stop_words = vec![
"a",
"an",
"and",
"are",
"as",
"at",
"be",
"by",
"for",
"from",
"has",
"he",
"in",
"is",
"it",
"its",
"of",
"on",
"that",
"the",
"to",
"was",
"were",
"will",
"with",
"would",
"you",
"your",
"i",
"me",
"my",
"we",
"us",
"our",
"they",
"them",
"their",
"she",
"her",
"him",
"his",
"this",
"these",
"those",
"but",
"or",
"not",
"can",
"could",
"should",
"would",
"have",
"had",
"do",
"does",
"did",
"am",
"been",
"being",
"get",
"got",
"go",
"goes",
"went",
"come",
"came",
"see",
"saw",
"know",
"knew",
"think",
"thought",
"say",
"said",
"tell",
"told",
"take",
"took",
"give",
"gave",
"make",
"made",
"find",
"found",
"use",
"used",
"work",
"worked",
"call",
"called",
"try",
"tried",
"ask",
"asked",
"need",
"needed",
"feel",
"felt",
"become",
"became",
"leave",
"left",
"put",
"puts",
"seem",
"seemed",
"turn",
"turned",
"start",
"started",
"show",
"showed",
"hear",
"heard",
"play",
"played",
"run",
"ran",
"move",
"moved",
"live",
"lived",
"believe",
"believed",
"hold",
"held",
"bring",
"brought",
"happen",
"happened",
"write",
"wrote",
"provide",
"provided",
"sit",
"sat",
"stand",
"stood",
"lose",
"lost",
"pay",
"paid",
"meet",
"met",
"include",
"included",
"continue",
"continued",
"set",
"sets",
"learn",
"learned",
"change",
"changed",
"lead",
"led",
"understand",
"understood",
"watch",
"watched",
"follow",
"followed",
"stop",
"stopped",
"create",
"created",
"speak",
"spoke",
"read",
"reads",
"allow",
"allowed",
"add",
"added",
"spend",
"spent",
"grow",
"grew",
"open",
"opened",
"walk",
"walked",
"win",
"won",
"offer",
"offered",
"remember",
"remembered",
"love",
"loved",
"consider",
"considered",
"appear",
"appeared",
"buy",
"bought",
"wait",
"waited",
"serve",
"served",
"die",
"died",
"send",
"sent",
"expect",
"expected",
"build",
"built",
"stay",
"stayed",
"fall",
"fell",
"cut",
"cuts",
"reach",
"reached",
"kill",
"killed",
"remain",
"remained",
];
stop_words.into_iter().map(|s| s.to_string()).collect()
}
pub fn index_document(&self, doc: &Document) -> Result<()> {
if let Some(Value::String(text)) = doc.data.get(&self.field) {
self.remove_document_internal(&doc._sid);
let terms = self.tokenize(text);
if terms.is_empty() {
return Ok(());
}
let doc_length = terms.len();
self.doc_lengths.insert(doc._sid.clone(), doc_length);
self.document_count.fetch_add(1, Ordering::Relaxed);
self.total_term_count
.fetch_add(doc_length as u64, Ordering::Relaxed);
let term_freq = self.calculate_term_frequencies(&terms);
for (term, freq) in term_freq {
self.index
.entry(term.clone())
.or_default()
.push((doc._sid.clone(), freq));
}
self.rebuild_fst();
}
Ok(())
}
pub fn search(&self, query: &str) -> Vec<(String, f32)> {
let query_terms = self.tokenize(query);
if query_terms.is_empty() {
return Vec::new();
}
let mut scores: HashMap<String, f32> = HashMap::new();
let total_docs = self.document_count.load(Ordering::Relaxed).max(1) as f32;
let avg_doc_len = self.total_term_count.load(Ordering::Relaxed) as f32 / total_docs;
for term in query_terms {
let Some(term_entry) = self.index.get(&term) else {
continue;
};
let doc_freq = term_entry.len() as f32;
let idf = if doc_freq >= total_docs {
0.1
} else {
((total_docs - doc_freq + 0.5) / (doc_freq + 0.5))
.ln()
.max(0.1)
};
for (doc_id, tf) in term_entry.iter() {
let doc_len = self
.doc_lengths
.get(doc_id)
.map(|entry| *entry.value() as f32)
.unwrap_or(avg_doc_len);
let tf_numerator = tf * (self.k1 + 1.0);
let tf_denominator =
tf + self.k1 * (1.0 - self.b + self.b * (doc_len / avg_doc_len));
let tf_bm25 = tf_numerator / tf_denominator;
*scores.entry(doc_id.clone()).or_insert(0.0) += tf_bm25 * idf;
}
}
let mut results: Vec<_> = scores.into_iter().collect();
results.sort_by(|(_, score1), (_, score2)| score2.partial_cmp(score1).unwrap());
results
}
pub fn fuzzy_search(&self, query: &str, max_distance: usize) -> Vec<(String, f32)> {
let query_terms = self.tokenize(query);
if query_terms.is_empty() {
return Vec::new();
}
let mut scores: HashMap<String, f32> = HashMap::new();
let total_docs = self.document_count.load(Ordering::Relaxed).max(1) as f32;
let avg_doc_len = self.total_term_count.load(Ordering::Relaxed) as f32 / total_docs;
if let Ok(fst_guard) = self.term_fst.read()
&& let Some(ref fst) = *fst_guard {
for query_term in query_terms {
if let Ok(lev) = Levenshtein::new(&query_term, max_distance as u32) {
let mut stream = fst.search(lev).into_stream();
while let Some(term_bytes) = stream.next() {
let term = String::from_utf8_lossy(term_bytes);
if let Some(term_entry) = self.index.get(term.as_ref()) {
let distance = self.levenshtein_distance(&query_term, &term) as f32;
let distance_penalty = 1.0 / (1.0 + distance * 0.3);
let doc_freq = term_entry.len() as f32;
let idf = if doc_freq >= total_docs {
0.1
} else {
((total_docs - doc_freq + 0.5) / (doc_freq + 0.5))
.ln()
.max(0.1)
};
for (doc_id, tf) in term_entry.iter() {
let doc_len = self
.doc_lengths
.get(doc_id)
.map(|entry| *entry.value() as f32)
.unwrap_or(avg_doc_len);
let tf_numerator = tf * (self.k1 + 1.0);
let tf_denominator = tf
+ self.k1
* (1.0 - self.b + self.b * (doc_len / avg_doc_len));
let tf_bm25 = tf_numerator / tf_denominator;
let score = tf_bm25 * idf * distance_penalty;
*scores.entry(doc_id.clone()).or_insert(0.0) += score;
}
}
}
}
}
}
let mut results: Vec<_> = scores.into_iter().collect();
results.sort_by(|(_, score1), (_, score2)| score2.partial_cmp(score1).unwrap());
results
}
pub fn remove_document(&self, doc_id: &str) {
self.remove_document_internal(doc_id);
self.rebuild_fst();
}
fn remove_document_internal(&self, doc_id: &str) {
if let Some((_, doc_length)) = self.doc_lengths.remove(doc_id) {
self.document_count.fetch_sub(1, Ordering::Relaxed);
self.total_term_count
.fetch_sub(doc_length as u64, Ordering::Relaxed);
}
let mut empty_terms = Vec::new();
for mut entry in self.index.iter_mut() {
let term = entry.key().clone();
let doc_entries = entry.value_mut();
doc_entries.retain(|(id, _)| id != doc_id);
if doc_entries.is_empty() {
empty_terms.push(term);
}
}
for term in empty_terms {
self.index.remove(&term);
}
}
fn rebuild_fst(&self) {
let terms: Vec<String> = self.index.iter().map(|entry| entry.key().clone()).collect();
if !terms.is_empty() {
let mut sorted_terms = terms;
sorted_terms.sort();
if let Ok(fst) = fst::Set::from_iter(sorted_terms)
&& let Ok(mut fst_guard) = self.term_fst.write() {
*fst_guard = Some(fst);
}
}
}
pub fn highlight_matches(&self, text: &str, query: &str) -> String {
let query_terms: Vec<String> = self
.tokenize(query)
.into_iter()
.map(|t| regex::escape(&t))
.collect();
if query_terms.is_empty() {
return text.to_string();
}
let pattern = format!("({})", query_terms.join("|"));
match regex::RegexBuilder::new(&pattern)
.case_insensitive(true)
.build()
{
Ok(re) => re.replace_all(text, "<mark>$0</mark>").to_string(),
Err(_) => text.to_string(),
}
}
fn tokenize(&self, text: &str) -> Vec<String> {
let mut tokens: Vec<String> = text
.unicode_words()
.map(|word| word.to_lowercase())
.filter(|word| word.len() > 1) .collect();
if self.enable_stop_words {
tokens.retain(|token| !self.stop_words.contains(token));
}
tokens
}
fn calculate_term_frequencies(&self, terms: &[String]) -> HashMap<String, f32> {
let mut freq = HashMap::new();
for term in terms {
*freq.entry(term.clone()).or_insert(0.0) += 1.0;
}
freq
}
fn levenshtein_distance(&self, s1: &str, s2: &str) -> usize {
let s1_chars: Vec<char> = s1.chars().collect();
let s2_chars: Vec<char> = s2.chars().collect();
let s1_len = s1_chars.len();
let s2_len = s2_chars.len();
if s1_len == 0 {
return s2_len;
}
if s2_len == 0 {
return s1_len;
}
let mut prev_row = (0..=s2_len).collect::<Vec<_>>();
let mut curr_row = vec![0; s2_len + 1];
for i in 1..=s1_len {
curr_row[0] = i;
for j in 1..=s2_len {
let cost = if s1_chars[i - 1] == s2_chars[j - 1] {
0
} else {
1
};
curr_row[j] = std::cmp::min(
curr_row[j - 1] + 1, std::cmp::min(
prev_row[j] + 1, prev_row[j - 1] + cost, ),
);
}
std::mem::swap(&mut prev_row, &mut curr_row);
}
prev_row[s2_len]
}
pub fn set_stop_words_filter(&mut self, enable: bool) {
self.enable_stop_words = enable;
}
pub fn add_stopwords(&mut self, words: &[&str]) {
for word in words {
self.stop_words.insert(word.to_lowercase());
}
}
pub fn clear_stopwords(&mut self) {
self.stop_words.clear();
}
pub fn get_stopwords(&self) -> Vec<String> {
self.stop_words.iter().cloned().collect()
}
pub fn get_stats(&self) -> IndexStats {
IndexStats {
total_documents: self.document_count.load(Ordering::Relaxed),
total_terms: self.index.len(),
total_term_instances: self.total_term_count.load(Ordering::Relaxed),
average_document_length: if self.document_count.load(Ordering::Relaxed) > 0 {
self.total_term_count.load(Ordering::Relaxed) as f32
/ self.document_count.load(Ordering::Relaxed) as f32
} else {
0.0
},
}
}
pub fn get_fuzzy_matches(&self, query: &str, max_distance: u32) -> Vec<(String, f32)> {
self.fuzzy_search(query, max_distance as usize)
}
}
#[derive(Debug, Clone)]
pub struct IndexStats {
pub total_documents: usize,
pub total_terms: usize,
pub total_term_instances: u64,
pub average_document_length: f32,
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
#[test]
fn test_enhanced_search() -> Result<()> {
let index = FullTextIndex::new("test_collection", "content");
let doc1 = Document {
_sid: "1".to_string(),
data: {
let mut map = HashMap::new();
map.insert(
"content".to_string(),
Value::String("wireless bluetooth headphones".to_string()),
);
map
},
};
let doc2 = Document {
_sid: "2".to_string(),
data: {
let mut map = HashMap::new();
map.insert("content".to_string(), Value::String(
"wireless bluetooth headphones with excellent sound quality and noise cancellation technology for music lovers".to_string()
));
map
},
};
index.index_document(&doc1)?;
index.index_document(&doc2)?;
let results = index.search("wireless");
assert_eq!(results.len(), 2);
let doc1_score = results.iter().find(|(id, _)| id == "1").unwrap().1;
let doc2_score = results.iter().find(|(id, _)| id == "2").unwrap().1;
assert!(
doc1_score > doc2_score,
"Shorter document should have higher BM25 score: {} vs {}",
doc1_score,
doc2_score
);
Ok(())
}
#[test]
fn test_fuzzy_search() -> Result<()> {
let index = FullTextIndex::new("test", "content");
let doc = Document {
_sid: "1".to_string(),
data: {
let mut map = HashMap::new();
map.insert(
"content".to_string(),
Value::String("wireless bluetooth technology".to_string()),
);
map
},
};
index.index_document(&doc)?;
let results = index.fuzzy_search("wireles", 2);
assert!(!results.is_empty(), "Should find fuzzy matches");
let results = index.fuzzy_search("bluetoth", 2);
assert!(!results.is_empty(), "Should find fuzzy matches");
Ok(())
}
#[test]
fn test_stop_words() -> Result<()> {
let index = FullTextIndex::new("test", "content");
let doc = Document {
_sid: "1".to_string(),
data: {
let mut map = HashMap::new();
map.insert(
"content".to_string(),
Value::String("the quick brown fox".to_string()),
);
map
},
};
index.index_document(&doc)?;
let results = index.search("the");
assert!(results.is_empty(), "Stop words should be filtered");
let results = index.search("quick");
assert!(!results.is_empty(), "Non-stop words should be found");
Ok(())
}
#[test]
fn test_highlight_matches() -> Result<()> {
let index = FullTextIndex::new("test", "content");
let text = "This is a wireless bluetooth device";
let highlighted = index.highlight_matches(text, "wireless bluetooth");
assert!(highlighted.contains("<mark>wireless</mark>"));
assert!(highlighted.contains("<mark>bluetooth</mark>"));
Ok(())
}
#[test]
fn test_document_removal() -> Result<()> {
let index = FullTextIndex::new("test", "content");
let doc1 = Document {
_sid: "1".to_string(),
data: {
let mut map = HashMap::new();
map.insert(
"content".to_string(),
Value::String("wireless technology".to_string()),
);
map
},
};
let doc2 = Document {
_sid: "2".to_string(),
data: {
let mut map = HashMap::new();
map.insert(
"content".to_string(),
Value::String("bluetooth technology".to_string()),
);
map
},
};
index.index_document(&doc1)?;
index.index_document(&doc2)?;
let results = index.search("technology");
assert_eq!(results.len(), 2);
index.remove_document("1");
let results = index.search("technology");
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, "2");
Ok(())
}
#[test]
fn test_levenshtein_distance() -> Result<()> {
let index = FullTextIndex::new("test", "content");
assert_eq!(index.levenshtein_distance("kitten", "sitting"), 3);
assert_eq!(index.levenshtein_distance("saturday", "sunday"), 3);
assert_eq!(index.levenshtein_distance("", "abc"), 3);
assert_eq!(index.levenshtein_distance("abc", ""), 3);
assert_eq!(index.levenshtein_distance("same", "same"), 0);
Ok(())
}
}