use crate::error::Result;
use crate::types::{Document, Value};
use dashmap::DashMap;
use fst::automaton::Levenshtein;
use fst::{IntoStreamer, Streamer};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use unicode_segmentation::UnicodeSegmentation;

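/// Inverted full-text index over a single string field.
///
/// Postings are stored as `term -> [(doc_id, term_frequency)]` in a
/// concurrent `DashMap`. Per-document lengths and global counters feed
/// Okapi BM25 ranking (tunable via `k1` and `b`), and a finite-state
/// transducer (`fst::Set`) over the term vocabulary supports fuzzy
/// (Levenshtein) matching.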
#[allow(dead_code)]
pub struct FullTextIndex {
    field: String,

    index: Arc<DashMap<String, Vec<(String, f32)>>>,

    doc_lengths: Arc<DashMap<String, usize>>,
    document_count: Arc<AtomicUsize>,
    total_term_count: Arc<AtomicU64>,

    term_fst: Arc<std::sync::RwLock<Option<fst::Set<Vec<u8>>>>>,

    stop_words: HashSet<String>,
    enable_stop_words: bool,

    k1: f32,
    b: f32,
}

impl FullTextIndex {
    pub fn new(_collection: &str, field: &str) -> Self {
        Self {
            field: field.to_string(),
            index: Arc::new(DashMap::new()),
            doc_lengths: Arc::new(DashMap::new()),
            document_count: Arc::new(AtomicUsize::new(0)),
            total_term_count: Arc::new(AtomicU64::new(0)),
            term_fst: Arc::new(std::sync::RwLock::new(None)),
            stop_words: Self::default_stop_words(),
            enable_stop_words: true,
            k1: 1.2,
            b: 0.75,
        }
    }

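    /// Like `new`, but with explicit BM25 parameters (`k1`, `b`) and an
    /// explicit stop-word filtering toggle.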
    pub fn with_options(
        collection: &str,
        field: &str,
        k1: f32,
        b: f32,
        enable_stop_words: bool,
    ) -> Self {
        let mut index = Self::new(collection, field);
        index.k1 = k1;
        index.b = b;
        index.enable_stop_words = enable_stop_words;
        index
    }

    fn default_stop_words() -> HashSet<String> {
        let stop_words = vec![
            "a", "an", "and", "are", "as", "at", "be", "by", "for", "from",
            "has", "he", "in", "is", "it", "its", "of", "on", "that", "the",
            "to", "was", "were", "will", "with", "would", "you", "your", "i",
            "me", "my", "we", "us", "our", "they", "them", "their", "she",
            "her", "him", "his", "this", "these", "those", "but", "or",
            "not", "can", "could", "should", "have", "had", "do", "does",
            "did", "am", "been", "being", "get", "got", "go", "goes",
            "went", "come", "came", "see", "saw", "know", "knew", "think",
            "thought", "say", "said", "tell", "told", "take", "took",
            "give", "gave", "make", "made", "find", "found", "use", "used",
            "work", "worked", "call", "called", "try", "tried", "ask",
            "asked", "need", "needed", "feel", "felt", "become", "became",
            "leave", "left", "put", "puts", "seem", "seemed", "turn",
            "turned", "start", "started", "show", "showed", "hear",
            "heard", "play", "played", "run", "ran", "move", "moved",
            "live", "lived", "believe", "believed", "hold", "held",
            "bring", "brought", "happen", "happened", "write", "wrote",
            "provide", "provided", "sit", "sat", "stand", "stood", "lose",
            "lost", "pay", "paid", "meet", "met", "include", "included",
            "continue", "continued", "set", "sets", "learn", "learned",
            "change", "changed", "lead", "led", "understand", "understood",
            "watch", "watched", "follow", "followed", "stop", "stopped",
            "create", "created", "speak", "spoke", "read", "reads",
            "allow", "allowed", "add", "added", "spend", "spent", "grow",
            "grew", "open", "opened", "walk", "walked", "win", "won",
            "offer", "offered", "remember", "remembered", "love", "loved",
            "consider", "considered", "appear", "appeared", "buy",
            "bought", "wait", "waited", "serve", "served", "die", "died",
            "send", "sent", "expect", "expected", "build", "built",
            "stay", "stayed", "fall", "fell", "cut", "cuts", "reach",
            "reached", "kill", "killed", "remain", "remained",
        ];

        stop_words.into_iter().map(|s| s.to_string()).collect()
    }

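    /// Indexes the configured string field of `doc`. Any previous postings
    /// for the same document id are removed first; the text is tokenized,
    /// document statistics are updated, and the term FST is rebuilt.
    /// Documents whose field is missing or not a string are skipped.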
    pub fn index_document(&self, doc: &Document) -> Result<()> {
        if let Some(Value::String(text)) = doc.data.get(&self.field) {
            self.remove_document_internal(&doc.id);

            let terms = self.tokenize(text);
            if terms.is_empty() {
                // The old postings were just removed, so keep the FST in
                // sync even though there is nothing new to index.
                self.rebuild_fst();
                return Ok(());
            }

            let doc_length = terms.len();
            self.doc_lengths.insert(doc.id.clone(), doc_length);
            self.document_count.fetch_add(1, Ordering::Relaxed);
            self.total_term_count
                .fetch_add(doc_length as u64, Ordering::Relaxed);

            let term_freq = self.calculate_term_frequencies(&terms);

            for (term, freq) in term_freq {
                self.index
                    .entry(term)
                    .or_default()
                    .push((doc.id.clone(), freq));
            }

            self.rebuild_fst();
        }
        Ok(())
    }

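    /// Ranks matching documents with Okapi BM25. For each query term with
    /// document frequency `df` in a corpus of `N` documents:
    ///
    /// `idf = ln((N - df + 0.5) / (df + 0.5))`, floored at 0.1, and each
    /// matching document contributes
    /// `idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avg_doc_len))`.
    ///
    /// The floor keeps very common terms contributing a small positive
    /// score instead of zero or a negative one.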
    pub fn search(&self, query: &str) -> Vec<(String, f32)> {
        let query_terms = self.tokenize(query);
        if query_terms.is_empty() {
            return Vec::new();
        }

        let mut scores: HashMap<String, f32> = HashMap::new();
        let total_docs = self.document_count.load(Ordering::Relaxed).max(1) as f32;
        let avg_doc_len = self.total_term_count.load(Ordering::Relaxed) as f32 / total_docs;

        for term in query_terms {
            let Some(term_entry) = self.index.get(&term) else {
                continue;
            };

            let doc_freq = term_entry.len() as f32;
            let idf = if doc_freq >= total_docs {
                0.1
            } else {
                ((total_docs - doc_freq + 0.5) / (doc_freq + 0.5))
                    .ln()
                    .max(0.1)
            };

            for (doc_id, tf) in term_entry.iter() {
                let doc_len = self
                    .doc_lengths
                    .get(doc_id)
                    .map(|entry| *entry.value() as f32)
                    .unwrap_or(avg_doc_len);

                let tf_numerator = tf * (self.k1 + 1.0);
                let tf_denominator =
                    tf + self.k1 * (1.0 - self.b + self.b * (doc_len / avg_doc_len));
                let tf_bm25 = tf_numerator / tf_denominator;

                *scores.entry(doc_id.clone()).or_insert(0.0) += tf_bm25 * idf;
            }
        }

        let mut results: Vec<_> = scores.into_iter().collect();
        // total_cmp imposes a total order on f32, so this cannot panic on NaN.
        results.sort_by(|(_, score1), (_, score2)| score2.total_cmp(score1));
        results
    }

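    /// Fuzzy variant of [`search`](Self::search): each query term is
    /// expanded to every indexed term within `max_distance` edits by running
    /// a Levenshtein automaton over the term FST. Matches are scored with
    /// the same BM25 formula, discounted by `1 / (1 + 0.3 * edit_distance)`
    /// so exact terms outrank approximate ones.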
    pub fn fuzzy_search(&self, query: &str, max_distance: usize) -> Vec<(String, f32)> {
        let query_terms = self.tokenize(query);
        if query_terms.is_empty() {
            return Vec::new();
        }

        let mut scores: HashMap<String, f32> = HashMap::new();
        let total_docs = self.document_count.load(Ordering::Relaxed).max(1) as f32;
        let avg_doc_len = self.total_term_count.load(Ordering::Relaxed) as f32 / total_docs;

        if let Ok(fst_guard) = self.term_fst.read() {
            if let Some(ref fst) = *fst_guard {
                for query_term in query_terms {
                    if let Ok(lev) = Levenshtein::new(&query_term, max_distance as u32) {
                        let mut stream = fst.search(lev).into_stream();

                        while let Some(term_bytes) = stream.next() {
                            let term = String::from_utf8_lossy(term_bytes);

                            if let Some(term_entry) = self.index.get(term.as_ref()) {
                                let distance =
                                    self.levenshtein_distance(&query_term, &term) as f32;
                                let distance_penalty = 1.0 / (1.0 + distance * 0.3);

                                let doc_freq = term_entry.len() as f32;
                                let idf = if doc_freq >= total_docs {
                                    0.1
                                } else {
                                    ((total_docs - doc_freq + 0.5) / (doc_freq + 0.5))
                                        .ln()
                                        .max(0.1)
                                };

                                for (doc_id, tf) in term_entry.iter() {
                                    let doc_len = self
                                        .doc_lengths
                                        .get(doc_id)
                                        .map(|entry| *entry.value() as f32)
                                        .unwrap_or(avg_doc_len);

                                    let tf_numerator = tf * (self.k1 + 1.0);
                                    let tf_denominator = tf
                                        + self.k1
                                            * (1.0 - self.b + self.b * (doc_len / avg_doc_len));
                                    let tf_bm25 = tf_numerator / tf_denominator;

                                    let score = tf_bm25 * idf * distance_penalty;
                                    *scores.entry(doc_id.clone()).or_insert(0.0) += score;
                                }
                            }
                        }
                    }
                }
            }
        }

        let mut results: Vec<_> = scores.into_iter().collect();
        results.sort_by(|(_, score1), (_, score2)| score2.total_cmp(score1));
        results
    }

    pub fn remove_document(&self, doc_id: &str) {
        self.remove_document_internal(doc_id);
        self.rebuild_fst();
    }

    fn remove_document_internal(&self, doc_id: &str) {
        if let Some((_, doc_length)) = self.doc_lengths.remove(doc_id) {
            self.document_count.fetch_sub(1, Ordering::Relaxed);
            self.total_term_count
                .fetch_sub(doc_length as u64, Ordering::Relaxed);
        }

        let mut empty_terms = Vec::new();

        for mut entry in self.index.iter_mut() {
            let term = entry.key().clone();
            let doc_entries = entry.value_mut();

            doc_entries.retain(|(id, _)| id != doc_id);

            if doc_entries.is_empty() {
                empty_terms.push(term);
            }
        }

        for term in empty_terms {
            self.index.remove(&term);
        }
    }

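    /// Rebuilds the term FST from the current vocabulary. This is a full
    /// rebuild (collect, sort, reconstruct), so it costs O(t log t) in the
    /// number of distinct terms and runs on every insert or removal.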
    fn rebuild_fst(&self) {
        let terms: Vec<String> = self.index.iter().map(|entry| entry.key().clone()).collect();

        if !terms.is_empty() {
            let mut sorted_terms = terms;
            sorted_terms.sort();

            if let Ok(fst) = fst::Set::from_iter(sorted_terms) {
                if let Ok(mut fst_guard) = self.term_fst.write() {
                    *fst_guard = Some(fst);
                }
            }
        } else if let Ok(mut fst_guard) = self.term_fst.write() {
            // No terms left: clear the FST so stale terms stop matching fuzzily.
            *fst_guard = None;
        }
    }

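    /// Wraps every occurrence of a query term in `text` with `<mark>` tags.
    /// Query terms go through the same tokenizer as indexing, so stop words
    /// (when enabled) are never highlighted.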
    pub fn highlight_matches(&self, text: &str, query: &str) -> String {
        let query_terms: Vec<String> = self
            .tokenize(query)
            .into_iter()
            .map(|t| regex::escape(&t))
            .collect();

        if query_terms.is_empty() {
            return text.to_string();
        }

        // Word boundaries keep "run" from highlighting the middle of
        // "running", mirroring the whole-word semantics of the index.
        let pattern = format!(r"\b({})\b", query_terms.join("|"));

        match regex::RegexBuilder::new(&pattern)
            .case_insensitive(true)
            .build()
        {
            Ok(re) => re.replace_all(text, "<mark>$0</mark>").to_string(),
            Err(_) => text.to_string(),
        }
    }

    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens: Vec<String> = text
            .unicode_words()
            .map(|word| word.to_lowercase())
            .filter(|word| word.len() > 1)
            .collect();

        if self.enable_stop_words {
            tokens.retain(|token| !self.stop_words.contains(token));
        }

        tokens
    }

    fn calculate_term_frequencies(&self, terms: &[String]) -> HashMap<String, f32> {
        let mut freq = HashMap::new();
        for term in terms {
            *freq.entry(term.clone()).or_insert(0.0) += 1.0;
        }
        freq
    }

    fn levenshtein_distance(&self, s1: &str, s2: &str) -> usize {
        let s1_chars: Vec<char> = s1.chars().collect();
        let s2_chars: Vec<char> = s2.chars().collect();
        let s1_len = s1_chars.len();
        let s2_len = s2_chars.len();

        if s1_len == 0 {
            return s2_len;
        }
        if s2_len == 0 {
            return s1_len;
        }

        let mut prev_row = (0..=s2_len).collect::<Vec<_>>();
        let mut curr_row = vec![0; s2_len + 1];

        for i in 1..=s1_len {
            curr_row[0] = i;

            for j in 1..=s2_len {
                let cost = if s1_chars[i - 1] == s2_chars[j - 1] { 0 } else { 1 };
                curr_row[j] = (curr_row[j - 1] + 1)
                    .min(prev_row[j] + 1)
                    .min(prev_row[j - 1] + cost);
            }

            std::mem::swap(&mut prev_row, &mut curr_row);
        }

        prev_row[s2_len]
    }

    pub fn set_stop_words_filter(&mut self, enable: bool) {
        self.enable_stop_words = enable;
    }

    pub fn add_stopwords(&mut self, words: &[&str]) {
        for word in words {
            self.stop_words.insert(word.to_lowercase());
        }
    }

    pub fn clear_stopwords(&mut self) {
        self.stop_words.clear();
    }

    pub fn get_stopwords(&self) -> Vec<String> {
        self.stop_words.iter().cloned().collect()
    }

    pub fn get_stats(&self) -> IndexStats {
        let total_documents = self.document_count.load(Ordering::Relaxed);
        let total_term_instances = self.total_term_count.load(Ordering::Relaxed);

        IndexStats {
            total_documents,
            total_terms: self.index.len(),
            total_term_instances,
            average_document_length: if total_documents > 0 {
                total_term_instances as f32 / total_documents as f32
            } else {
                0.0
            },
        }
    }

    pub fn get_fuzzy_matches(&self, query: &str, max_distance: u32) -> Vec<(String, f32)> {
        self.fuzzy_search(query, max_distance as usize)
    }
}

#[derive(Debug, Clone)]
pub struct IndexStats {
    pub total_documents: usize,
    pub total_terms: usize,
    pub total_term_instances: u64,
    pub average_document_length: f32,
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_enhanced_search() -> Result<()> {
        let index = FullTextIndex::new("test_collection", "content");

        let doc1 = Document {
            id: "1".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert(
                    "content".to_string(),
                    Value::String("wireless bluetooth headphones".to_string()),
                );
                map
            },
        };

        let doc2 = Document {
            id: "2".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert("content".to_string(), Value::String(
                    "wireless bluetooth headphones with excellent sound quality and noise cancellation technology for music lovers".to_string()
                ));
                map
            },
        };

        index.index_document(&doc1)?;
        index.index_document(&doc2)?;

        let results = index.search("wireless");
        assert_eq!(results.len(), 2);

        let doc1_score = results.iter().find(|(id, _)| id == "1").unwrap().1;
        let doc2_score = results.iter().find(|(id, _)| id == "2").unwrap().1;

        assert!(
            doc1_score > doc2_score,
            "Shorter document should have higher BM25 score: {} vs {}",
            doc1_score,
            doc2_score
        );

        Ok(())
    }

    #[test]
    fn test_fuzzy_search() -> Result<()> {
        let index = FullTextIndex::new("test", "content");

        let doc = Document {
            id: "1".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert(
                    "content".to_string(),
                    Value::String("wireless bluetooth technology".to_string()),
                );
                map
            },
        };

        index.index_document(&doc)?;

        let results = index.fuzzy_search("wireles", 2);
        assert!(!results.is_empty(), "Should find fuzzy matches");

        let results = index.fuzzy_search("bluetoth", 2);
        assert!(!results.is_empty(), "Should find fuzzy matches");

        Ok(())
    }

    #[test]
    fn test_stop_words() -> Result<()> {
        let index = FullTextIndex::new("test", "content");

        let doc = Document {
            id: "1".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert(
                    "content".to_string(),
                    Value::String("the quick brown fox".to_string()),
                );
                map
            },
        };

        index.index_document(&doc)?;

        let results = index.search("the");
        assert!(results.is_empty(), "Stop words should be filtered");

        let results = index.search("quick");
        assert!(!results.is_empty(), "Non-stop words should be found");

        Ok(())
    }

    #[test]
    fn test_highlight_matches() -> Result<()> {
        let index = FullTextIndex::new("test", "content");

        let text = "This is a wireless bluetooth device";
        let highlighted = index.highlight_matches(text, "wireless bluetooth");

        assert!(highlighted.contains("<mark>wireless</mark>"));
        assert!(highlighted.contains("<mark>bluetooth</mark>"));

        Ok(())
    }

    #[test]
    fn test_document_removal() -> Result<()> {
        let index = FullTextIndex::new("test", "content");

        let doc1 = Document {
            id: "1".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert(
                    "content".to_string(),
                    Value::String("wireless technology".to_string()),
                );
                map
            },
        };

        let doc2 = Document {
            id: "2".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert(
                    "content".to_string(),
                    Value::String("bluetooth technology".to_string()),
                );
                map
            },
        };

        index.index_document(&doc1)?;
        index.index_document(&doc2)?;

        let results = index.search("technology");
        assert_eq!(results.len(), 2);

        index.remove_document("1");

        let results = index.search("technology");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "2");

        Ok(())
    }

    #[test]
    fn test_levenshtein_distance() -> Result<()> {
        let index = FullTextIndex::new("test", "content");

        assert_eq!(index.levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(index.levenshtein_distance("saturday", "sunday"), 3);
        assert_eq!(index.levenshtein_distance("", "abc"), 3);
        assert_eq!(index.levenshtein_distance("abc", ""), 3);
        assert_eq!(index.levenshtein_distance("same", "same"), 0);

        Ok(())
    }
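
    // Added sketch: `get_stats` should mirror the counters maintained during
    // indexing — one document, three distinct terms, three term instances.
    // The expected values follow directly from the tokenizer and counter
    // logic above.
    #[test]
    fn test_index_stats() -> Result<()> {
        let index = FullTextIndex::with_options("test", "content", 1.5, 0.75, true);

        let doc = Document {
            id: "1".to_string(),
            data: {
                let mut map = HashMap::new();
                map.insert(
                    "content".to_string(),
                    Value::String("wireless bluetooth headphones".to_string()),
                );
                map
            },
        };

        index.index_document(&doc)?;

        let stats = index.get_stats();
        assert_eq!(stats.total_documents, 1);
        assert_eq!(stats.total_terms, 3);
        assert_eq!(stats.total_term_instances, 3);
        assert!((stats.average_document_length - 3.0).abs() < f32::EPSILON);

        Ok(())
    }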
}