1use crate::error::Result;
36use crate::types::{Document, Value};
37use dashmap::DashMap;
38use fst::automaton::Levenshtein;
39use fst::{IntoStreamer, Streamer};
40use std::collections::{HashMap, HashSet};
41use std::sync::Arc;
42use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
43use unicode_segmentation::UnicodeSegmentation;
44
/// BM25-scored inverted index over a single string field of documents,
/// with fuzzy matching backed by an FST of the indexed terms.
#[allow(dead_code)]
pub struct FullTextIndex {
    /// Name of the document field this index covers.
    field: String,

    /// Inverted index: term -> postings of (doc_id, raw term frequency).
    index: Arc<DashMap<String, Vec<(String, f32)>>>,

    /// Per-document token count, used for BM25 length normalization.
    doc_lengths: Arc<DashMap<String, usize>>,
    /// Number of documents currently indexed.
    document_count: Arc<AtomicUsize>,
    /// Sum of all document lengths (total term occurrences).
    total_term_count: Arc<AtomicU64>,

    /// FST over the sorted term set, used by the Levenshtein automaton in
    /// fuzzy search; `None` until the first document is indexed.
    term_fst: Arc<std::sync::RwLock<Option<fst::Set<Vec<u8>>>>>,

    /// Terms dropped during tokenization when `enable_stop_words` is true.
    stop_words: HashSet<String>,
    /// Whether tokenization filters out `stop_words`.
    enable_stop_words: bool,

    // BM25 tuning: k1 controls term-frequency saturation, b controls
    // document-length normalization strength.
    k1: f32, b: f32, }
76
77impl FullTextIndex {
78 pub fn new(_collection: &str, field: &str) -> Self {
89 let mut index = Self {
90 field: field.to_string(),
91 index: Arc::new(DashMap::new()),
92 doc_lengths: Arc::new(DashMap::new()),
93 document_count: Arc::new(AtomicUsize::new(0)),
94 total_term_count: Arc::new(AtomicU64::new(0)),
95 term_fst: Arc::new(std::sync::RwLock::new(None)),
96 stop_words: Self::default_stop_words(),
97 enable_stop_words: true,
98 k1: 1.2, b: 0.75, };
101
102 index.stop_words = Self::default_stop_words();
104 index
105 }
106
107 pub fn with_options(
125 collection: &str,
126 field: &str,
127 k1: f32,
128 b: f32,
129 enable_stop_words: bool,
130 ) -> Self {
131 let mut index = Self::new(collection, field);
132 index.k1 = k1;
133 index.b = b;
134 index.enable_stop_words = enable_stop_words;
135 index
136 }
137
138 fn default_stop_words() -> HashSet<String> {
140 let stop_words = vec![
141 "a",
142 "an",
143 "and",
144 "are",
145 "as",
146 "at",
147 "be",
148 "by",
149 "for",
150 "from",
151 "has",
152 "he",
153 "in",
154 "is",
155 "it",
156 "its",
157 "of",
158 "on",
159 "that",
160 "the",
161 "to",
162 "was",
163 "were",
164 "will",
165 "with",
166 "would",
167 "you",
168 "your",
169 "i",
170 "me",
171 "my",
172 "we",
173 "us",
174 "our",
175 "they",
176 "them",
177 "their",
178 "she",
179 "her",
180 "him",
181 "his",
182 "this",
183 "these",
184 "those",
185 "but",
186 "or",
187 "not",
188 "can",
189 "could",
190 "should",
191 "would",
192 "have",
193 "had",
194 "do",
195 "does",
196 "did",
197 "am",
198 "been",
199 "being",
200 "get",
201 "got",
202 "go",
203 "goes",
204 "went",
205 "come",
206 "came",
207 "see",
208 "saw",
209 "know",
210 "knew",
211 "think",
212 "thought",
213 "say",
214 "said",
215 "tell",
216 "told",
217 "take",
218 "took",
219 "give",
220 "gave",
221 "make",
222 "made",
223 "find",
224 "found",
225 "use",
226 "used",
227 "work",
228 "worked",
229 "call",
230 "called",
231 "try",
232 "tried",
233 "ask",
234 "asked",
235 "need",
236 "needed",
237 "feel",
238 "felt",
239 "become",
240 "became",
241 "leave",
242 "left",
243 "put",
244 "puts",
245 "seem",
246 "seemed",
247 "turn",
248 "turned",
249 "start",
250 "started",
251 "show",
252 "showed",
253 "hear",
254 "heard",
255 "play",
256 "played",
257 "run",
258 "ran",
259 "move",
260 "moved",
261 "live",
262 "lived",
263 "believe",
264 "believed",
265 "hold",
266 "held",
267 "bring",
268 "brought",
269 "happen",
270 "happened",
271 "write",
272 "wrote",
273 "provide",
274 "provided",
275 "sit",
276 "sat",
277 "stand",
278 "stood",
279 "lose",
280 "lost",
281 "pay",
282 "paid",
283 "meet",
284 "met",
285 "include",
286 "included",
287 "continue",
288 "continued",
289 "set",
290 "sets",
291 "learn",
292 "learned",
293 "change",
294 "changed",
295 "lead",
296 "led",
297 "understand",
298 "understood",
299 "watch",
300 "watched",
301 "follow",
302 "followed",
303 "stop",
304 "stopped",
305 "create",
306 "created",
307 "speak",
308 "spoke",
309 "read",
310 "reads",
311 "allow",
312 "allowed",
313 "add",
314 "added",
315 "spend",
316 "spent",
317 "grow",
318 "grew",
319 "open",
320 "opened",
321 "walk",
322 "walked",
323 "win",
324 "won",
325 "offer",
326 "offered",
327 "remember",
328 "remembered",
329 "love",
330 "loved",
331 "consider",
332 "considered",
333 "appear",
334 "appeared",
335 "buy",
336 "bought",
337 "wait",
338 "waited",
339 "serve",
340 "served",
341 "die",
342 "died",
343 "send",
344 "sent",
345 "expect",
346 "expected",
347 "build",
348 "built",
349 "stay",
350 "stayed",
351 "fall",
352 "fell",
353 "cut",
354 "cuts",
355 "reach",
356 "reached",
357 "kill",
358 "killed",
359 "remain",
360 "remained",
361 ];
362
363 stop_words.into_iter().map(|s| s.to_string()).collect()
364 }
365
366 pub fn index_document(&self, doc: &Document) -> Result<()> {
379 if let Some(Value::String(text)) = doc.data.get(&self.field) {
380 self.remove_document_internal(&doc.id);
382
383 let terms = self.tokenize(text);
384 if terms.is_empty() {
385 return Ok(());
386 }
387
388 let doc_length = terms.len();
390 self.doc_lengths.insert(doc.id.clone(), doc_length);
391 self.document_count.fetch_add(1, Ordering::Relaxed);
392 self.total_term_count
393 .fetch_add(doc_length as u64, Ordering::Relaxed);
394
395 let term_freq = self.calculate_term_frequencies(&terms);
397
398 for (term, freq) in term_freq {
400 self.index
401 .entry(term.clone())
402 .or_default()
403 .push((doc.id.clone(), freq));
404 }
405
406 self.rebuild_fst();
408 }
409 Ok(())
410 }
411
412 pub fn search(&self, query: &str) -> Vec<(String, f32)> {
428 let query_terms = self.tokenize(query);
429 if query_terms.is_empty() {
430 return Vec::new();
431 }
432
433 let mut scores: HashMap<String, f32> = HashMap::new();
434 let total_docs = self.document_count.load(Ordering::Relaxed).max(1) as f32;
435 let avg_doc_len = self.total_term_count.load(Ordering::Relaxed) as f32 / total_docs;
436
437 for term in query_terms {
438 let Some(term_entry) = self.index.get(&term) else {
439 continue;
440 };
441
442 let doc_freq = term_entry.len() as f32;
444 let idf = if doc_freq >= total_docs {
446 0.1
448 } else {
449 ((total_docs - doc_freq + 0.5) / (doc_freq + 0.5))
450 .ln()
451 .max(0.1)
452 };
453
454 for (doc_id, tf) in term_entry.iter() {
456 let doc_len = self
457 .doc_lengths
458 .get(doc_id)
459 .map(|entry| *entry.value() as f32)
460 .unwrap_or(avg_doc_len);
461
462 let tf_numerator = tf * (self.k1 + 1.0);
464 let tf_denominator =
465 tf + self.k1 * (1.0 - self.b + self.b * (doc_len / avg_doc_len));
466 let tf_bm25 = tf_numerator / tf_denominator;
467
468 *scores.entry(doc_id.clone()).or_insert(0.0) += tf_bm25 * idf;
469 }
470 }
471
472 let mut results: Vec<_> = scores.into_iter().collect();
474 results.sort_by(|(_, score1), (_, score2)| score2.partial_cmp(score1).unwrap());
475 results
476 }
477
478 pub fn fuzzy_search(&self, query: &str, max_distance: usize) -> Vec<(String, f32)> {
496 let query_terms = self.tokenize(query);
497 if query_terms.is_empty() {
498 return Vec::new();
499 }
500
501 let mut scores: HashMap<String, f32> = HashMap::new();
502 let total_docs = self.document_count.load(Ordering::Relaxed).max(1) as f32;
503 let avg_doc_len = self.total_term_count.load(Ordering::Relaxed) as f32 / total_docs;
504
505 if let Ok(fst_guard) = self.term_fst.read()
507 && let Some(ref fst) = *fst_guard {
508 for query_term in query_terms {
509 if let Ok(lev) = Levenshtein::new(&query_term, max_distance as u32) {
511 let mut stream = fst.search(lev).into_stream();
512
513 while let Some(term_bytes) = stream.next() {
515 let term = String::from_utf8_lossy(term_bytes);
516
517 if let Some(term_entry) = self.index.get(term.as_ref()) {
518 let distance = self.levenshtein_distance(&query_term, &term) as f32;
520 let distance_penalty = 1.0 / (1.0 + distance * 0.3); let doc_freq = term_entry.len() as f32;
524 let idf = if doc_freq >= total_docs {
525 0.1
527 } else {
528 ((total_docs - doc_freq + 0.5) / (doc_freq + 0.5))
529 .ln()
530 .max(0.1)
531 };
532
533 for (doc_id, tf) in term_entry.iter() {
534 let doc_len = self
535 .doc_lengths
536 .get(doc_id)
537 .map(|entry| *entry.value() as f32)
538 .unwrap_or(avg_doc_len);
539
540 let tf_numerator = tf * (self.k1 + 1.0);
541 let tf_denominator = tf
542 + self.k1
543 * (1.0 - self.b + self.b * (doc_len / avg_doc_len));
544 let tf_bm25 = tf_numerator / tf_denominator;
545
546 let score = tf_bm25 * idf * distance_penalty;
547 *scores.entry(doc_id.clone()).or_insert(0.0) += score;
548 }
549 }
550 }
551 }
552 }
553 }
554
555 let mut results: Vec<_> = scores.into_iter().collect();
557 results.sort_by(|(_, score1), (_, score2)| score2.partial_cmp(score1).unwrap());
558 results
559 }
560
    /// Removes all postings for `doc_id` and rebuilds the fuzzy-match FST.
    pub fn remove_document(&self, doc_id: &str) {
        self.remove_document_internal(doc_id);
        // The internal helper deliberately skips the FST rebuild so that
        // batch operations can amortize it; a single removal pays it here.
        self.rebuild_fst();
    }
577
578 fn remove_document_internal(&self, doc_id: &str) {
580 if let Some((_, doc_length)) = self.doc_lengths.remove(doc_id) {
582 self.document_count.fetch_sub(1, Ordering::Relaxed);
583 self.total_term_count
584 .fetch_sub(doc_length as u64, Ordering::Relaxed);
585 }
586
587 let mut empty_terms = Vec::new();
591
592 for mut entry in self.index.iter_mut() {
593 let term = entry.key().clone();
594 let doc_entries = entry.value_mut();
595
596 doc_entries.retain(|(id, _)| id != doc_id);
597
598 if doc_entries.is_empty() {
599 empty_terms.push(term);
600 }
601 }
602
603 for term in empty_terms {
605 self.index.remove(&term);
606 }
607 }
608
609 fn rebuild_fst(&self) {
611 let terms: Vec<String> = self.index.iter().map(|entry| entry.key().clone()).collect();
612
613 if !terms.is_empty() {
614 let mut sorted_terms = terms;
615 sorted_terms.sort();
616
617 if let Ok(fst) = fst::Set::from_iter(sorted_terms)
618 && let Ok(mut fst_guard) = self.term_fst.write() {
619 *fst_guard = Some(fst);
620 }
621 }
622 }
623
624 pub fn highlight_matches(&self, text: &str, query: &str) -> String {
642 let query_terms: Vec<String> = self
643 .tokenize(query)
644 .into_iter()
645 .map(|t| regex::escape(&t))
646 .collect();
647
648 if query_terms.is_empty() {
649 return text.to_string();
650 }
651
652 let pattern = format!("({})", query_terms.join("|"));
654
655 match regex::RegexBuilder::new(&pattern)
656 .case_insensitive(true)
657 .build()
658 {
659 Ok(re) => re.replace_all(text, "<mark>$0</mark>").to_string(),
660 Err(_) => text.to_string(),
661 }
662 }
663
664 fn tokenize(&self, text: &str) -> Vec<String> {
666 let mut tokens: Vec<String> = text
667 .unicode_words()
668 .map(|word| word.to_lowercase())
669 .filter(|word| word.len() > 1) .collect();
671
672 if self.enable_stop_words {
674 tokens.retain(|token| !self.stop_words.contains(token));
675 }
676
677 tokens
678 }
679
680 fn calculate_term_frequencies(&self, terms: &[String]) -> HashMap<String, f32> {
682 let mut freq = HashMap::new();
683 for term in terms {
684 *freq.entry(term.clone()).or_insert(0.0) += 1.0;
685 }
686 freq
687 }
688
689 fn levenshtein_distance(&self, s1: &str, s2: &str) -> usize {
691 let s1_chars: Vec<char> = s1.chars().collect();
692 let s2_chars: Vec<char> = s2.chars().collect();
693 let s1_len = s1_chars.len();
694 let s2_len = s2_chars.len();
695
696 if s1_len == 0 {
697 return s2_len;
698 }
699 if s2_len == 0 {
700 return s1_len;
701 }
702
703 let mut prev_row = (0..=s2_len).collect::<Vec<_>>();
704 let mut curr_row = vec![0; s2_len + 1];
705
706 for i in 1..=s1_len {
707 curr_row[0] = i;
708
709 for j in 1..=s2_len {
710 let cost = if s1_chars[i - 1] == s2_chars[j - 1] {
711 0
712 } else {
713 1
714 };
715 curr_row[j] = std::cmp::min(
716 curr_row[j - 1] + 1, std::cmp::min(
718 prev_row[j] + 1, prev_row[j - 1] + cost, ),
721 );
722 }
723
724 std::mem::swap(&mut prev_row, &mut curr_row);
725 }
726
727 prev_row[s2_len]
728 }
729
    /// Enables or disables stop-word filtering for future tokenization.
    /// Documents already indexed are not re-tokenized.
    pub fn set_stop_words_filter(&mut self, enable: bool) {
        self.enable_stop_words = enable;
    }
734
735 pub fn add_stopwords(&mut self, words: &[&str]) {
737 for word in words {
738 self.stop_words.insert(word.to_lowercase());
739 }
740 }
741
    /// Removes every entry from the stop-word set. Documents already
    /// indexed are not re-tokenized.
    pub fn clear_stopwords(&mut self) {
        self.stop_words.clear();
    }
746
    /// Returns the current stop words in unspecified (hash-set) order.
    pub fn get_stopwords(&self) -> Vec<String> {
        self.stop_words.iter().cloned().collect()
    }
751
752 pub fn get_stats(&self) -> IndexStats {
754 IndexStats {
755 total_documents: self.document_count.load(Ordering::Relaxed),
756 total_terms: self.index.len(),
757 total_term_instances: self.total_term_count.load(Ordering::Relaxed),
758 average_document_length: if self.document_count.load(Ordering::Relaxed) > 0 {
759 self.total_term_count.load(Ordering::Relaxed) as f32
760 / self.document_count.load(Ordering::Relaxed) as f32
761 } else {
762 0.0
763 },
764 }
765 }
766
    /// Convenience wrapper around [`Self::fuzzy_search`] accepting the
    /// maximum edit distance as a `u32`.
    pub fn get_fuzzy_matches(&self, query: &str, max_distance: u32) -> Vec<(String, f32)> {
        self.fuzzy_search(query, max_distance as usize)
    }
771}
772
/// Snapshot of aggregate index statistics.
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Number of documents currently indexed.
    pub total_documents: usize,
    /// Number of distinct terms in the inverted index.
    pub total_terms: usize,
    /// Total term occurrences across all documents.
    pub total_term_instances: u64,
    /// Mean document length in terms; 0.0 when the index is empty.
    pub average_document_length: f32,
}
781
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // BM25 length normalization: of two documents containing the query
    // term, the shorter one should receive the higher score.
    #[test]
    fn test_enhanced_search() -> Result<()> {
        let index = FullTextIndex::new("test_collection", "content");

        let doc1 = Document {
            id: "1".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert(
                    "content".to_string(),
                    Value::String("wireless bluetooth headphones".to_string()),
                );
                map
            },
        };

        let doc2 = Document {
            id: "2".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert("content".to_string(), Value::String(
                    "wireless bluetooth headphones with excellent sound quality and noise cancellation technology for music lovers".to_string()
                ));
                map
            },
        };

        index.index_document(&doc1)?;
        index.index_document(&doc2)?;

        let results = index.search("wireless");
        assert_eq!(results.len(), 2);

        let doc1_score = results.iter().find(|(id, _)| id == "1").unwrap().1;
        let doc2_score = results.iter().find(|(id, _)| id == "2").unwrap().1;

        assert!(
            doc1_score > doc2_score,
            "Shorter document should have higher BM25 score: {} vs {}",
            doc1_score,
            doc2_score
        );

        Ok(())
    }

    // Misspelled queries within edit distance 2 should still match terms
    // via the FST-backed Levenshtein automaton.
    #[test]
    fn test_fuzzy_search() -> Result<()> {
        let index = FullTextIndex::new("test", "content");

        let doc = Document {
            id: "1".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert(
                    "content".to_string(),
                    Value::String("wireless bluetooth technology".to_string()),
                );
                map
            },
        };

        index.index_document(&doc)?;

        let results = index.fuzzy_search("wireles", 2);
        assert!(!results.is_empty(), "Should find fuzzy matches");

        let results = index.fuzzy_search("bluetoth", 2);
        assert!(!results.is_empty(), "Should find fuzzy matches");

        Ok(())
    }

    // Stop words are dropped at tokenization time, so a stop-word query
    // can never match; regular terms still do.
    #[test]
    fn test_stop_words() -> Result<()> {
        let index = FullTextIndex::new("test", "content");

        let doc = Document {
            id: "1".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert(
                    "content".to_string(),
                    Value::String("the quick brown fox".to_string()),
                );
                map
            },
        };

        index.index_document(&doc)?;

        let results = index.search("the");
        assert!(results.is_empty(), "Stop words should be filtered");

        let results = index.search("quick");
        assert!(!results.is_empty(), "Non-stop words should be found");

        Ok(())
    }

    // Each query term occurrence in the text should be wrapped in
    // <mark> tags by highlight_matches.
    #[test]
    fn test_highlight_matches() -> Result<()> {
        let index = FullTextIndex::new("test", "content");

        let text = "This is a wireless bluetooth device";
        let highlighted = index.highlight_matches(text, "wireless bluetooth");

        assert!(highlighted.contains("<mark>wireless</mark>"));
        assert!(highlighted.contains("<mark>bluetooth</mark>"));

        Ok(())
    }

    // Removing one document must purge its postings even for terms it
    // shares with other documents ("technology" here).
    #[test]
    fn test_document_removal() -> Result<()> {
        let index = FullTextIndex::new("test", "content");

        let doc1 = Document {
            id: "1".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert(
                    "content".to_string(),
                    Value::String("wireless technology".to_string()),
                );
                map
            },
        };

        let doc2 = Document {
            id: "2".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert(
                    "content".to_string(),
                    Value::String("bluetooth technology".to_string()),
                );
                map
            },
        };

        index.index_document(&doc1)?;
        index.index_document(&doc2)?;

        let results = index.search("technology");
        assert_eq!(results.len(), 2);

        index.remove_document("1");

        let results = index.search("technology");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "2");

        Ok(())
    }

    // Known edit-distance fixtures for the DP implementation, including
    // empty-string and identical-string edge cases.
    #[test]
    fn test_levenshtein_distance() -> Result<()> {
        let index = FullTextIndex::new("test", "content");

        assert_eq!(index.levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(index.levenshtein_distance("saturday", "sunday"), 3);
        assert_eq!(index.levenshtein_distance("", "abc"), 3);
        assert_eq!(index.levenshtein_distance("abc", ""), 3);
        assert_eq!(index.levenshtein_distance("same", "same"), 0);

        Ok(())
    }
}