1use crate::error::Result;
36use crate::types::{Document, Value};
37use dashmap::DashMap;
38use fst::automaton::Levenshtein;
39use fst::{IntoStreamer, Streamer};
40use std::collections::{HashMap, HashSet};
41use std::sync::Arc;
42use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
43use unicode_segmentation::UnicodeSegmentation;
44
/// Computes the Levenshtein (edit) distance between two strings.
///
/// The result is the minimum number of single-character insertions, deletions,
/// or substitutions needed to turn `s1` into `s2`. Strings are compared as
/// sequences of `char`s (Unicode scalar values), not bytes, so a single
/// multi-byte character counts as one edit.
///
/// Uses the two-row dynamic-programming formulation: O(|s1| * |s2|) time and
/// O(|s2|) extra memory.
pub fn levenshtein_distance(s1: &str, s2: &str) -> usize {
    let left: Vec<char> = s1.chars().collect();
    let right: Vec<char> = s2.chars().collect();

    // Distance to an empty string is the length of the other one.
    if left.is_empty() || right.is_empty() {
        return left.len().max(right.len());
    }

    // `above[j]` = distance between the previous prefix of `left` and the first
    // `j` chars of `right`; `row` is the row currently being filled in.
    let mut above: Vec<usize> = (0..=right.len()).collect();
    let mut row: Vec<usize> = vec![0; right.len() + 1];

    for (i, &lc) in left.iter().enumerate() {
        row[0] = i + 1;
        for (j, &rc) in right.iter().enumerate() {
            let substitution = above[j] + usize::from(lc != rc);
            let deletion = above[j + 1] + 1;
            let insertion = row[j] + 1;
            row[j + 1] = substitution.min(deletion).min(insertion);
        }
        std::mem::swap(&mut above, &mut row);
    }

    above[right.len()]
}
74
75pub struct FullTextIndex {
76 field: String,
78
79 index: Arc<DashMap<String, Vec<(String, f32)>>>,
81
82 doc_lengths: Arc<DashMap<String, usize>>,
84 document_count: Arc<AtomicUsize>,
85 total_term_count: Arc<AtomicU64>,
86
87 term_fst: Arc<std::sync::RwLock<Option<fst::Set<Vec<u8>>>>>,
89
90 stop_words: HashSet<String>,
92 enable_stop_words: bool,
93
94 k1: f32, b: f32, }
98
99impl FullTextIndex {
100 pub fn new(_collection: &str, field: &str) -> Self {
111 let mut index = Self {
112 field: field.to_string(),
113 index: Arc::new(DashMap::new()),
114 doc_lengths: Arc::new(DashMap::new()),
115 document_count: Arc::new(AtomicUsize::new(0)),
116 total_term_count: Arc::new(AtomicU64::new(0)),
117 term_fst: Arc::new(std::sync::RwLock::new(None)),
118 stop_words: Self::default_stop_words(),
119 enable_stop_words: true,
120 k1: 1.2, b: 0.75, };
123
124 index.stop_words = Self::default_stop_words();
126 index
127 }
128
129 pub fn with_options(
147 collection: &str,
148 field: &str,
149 k1: f32,
150 b: f32,
151 enable_stop_words: bool,
152 ) -> Self {
153 let mut index = Self::new(collection, field);
154 index.k1 = k1;
155 index.b = b;
156 index.enable_stop_words = enable_stop_words;
157 index
158 }
159
160 fn default_stop_words() -> HashSet<String> {
162 let stop_words = vec![
163 "a",
164 "an",
165 "and",
166 "are",
167 "as",
168 "at",
169 "be",
170 "by",
171 "for",
172 "from",
173 "has",
174 "he",
175 "in",
176 "is",
177 "it",
178 "its",
179 "of",
180 "on",
181 "that",
182 "the",
183 "to",
184 "was",
185 "were",
186 "will",
187 "with",
188 "would",
189 "you",
190 "your",
191 "i",
192 "me",
193 "my",
194 "we",
195 "us",
196 "our",
197 "they",
198 "them",
199 "their",
200 "she",
201 "her",
202 "him",
203 "his",
204 "this",
205 "these",
206 "those",
207 "but",
208 "or",
209 "not",
210 "can",
211 "could",
212 "should",
213 "would",
214 "have",
215 "had",
216 "do",
217 "does",
218 "did",
219 "am",
220 "been",
221 "being",
222 "get",
223 "got",
224 "go",
225 "goes",
226 "went",
227 "come",
228 "came",
229 "see",
230 "saw",
231 "know",
232 "knew",
233 "think",
234 "thought",
235 "say",
236 "said",
237 "tell",
238 "told",
239 "take",
240 "took",
241 "give",
242 "gave",
243 "make",
244 "made",
245 "find",
246 "found",
247 "use",
248 "used",
249 "work",
250 "worked",
251 "call",
252 "called",
253 "try",
254 "tried",
255 "ask",
256 "asked",
257 "need",
258 "needed",
259 "feel",
260 "felt",
261 "become",
262 "became",
263 "leave",
264 "left",
265 "put",
266 "puts",
267 "seem",
268 "seemed",
269 "turn",
270 "turned",
271 "start",
272 "started",
273 "show",
274 "showed",
275 "hear",
276 "heard",
277 "play",
278 "played",
279 "run",
280 "ran",
281 "move",
282 "moved",
283 "live",
284 "lived",
285 "believe",
286 "believed",
287 "hold",
288 "held",
289 "bring",
290 "brought",
291 "happen",
292 "happened",
293 "write",
294 "wrote",
295 "provide",
296 "provided",
297 "sit",
298 "sat",
299 "stand",
300 "stood",
301 "lose",
302 "lost",
303 "pay",
304 "paid",
305 "meet",
306 "met",
307 "include",
308 "included",
309 "continue",
310 "continued",
311 "set",
312 "sets",
313 "learn",
314 "learned",
315 "change",
316 "changed",
317 "lead",
318 "led",
319 "understand",
320 "understood",
321 "watch",
322 "watched",
323 "follow",
324 "followed",
325 "stop",
326 "stopped",
327 "create",
328 "created",
329 "speak",
330 "spoke",
331 "read",
332 "reads",
333 "allow",
334 "allowed",
335 "add",
336 "added",
337 "spend",
338 "spent",
339 "grow",
340 "grew",
341 "open",
342 "opened",
343 "walk",
344 "walked",
345 "win",
346 "won",
347 "offer",
348 "offered",
349 "remember",
350 "remembered",
351 "love",
352 "loved",
353 "consider",
354 "considered",
355 "appear",
356 "appeared",
357 "buy",
358 "bought",
359 "wait",
360 "waited",
361 "serve",
362 "served",
363 "die",
364 "died",
365 "send",
366 "sent",
367 "expect",
368 "expected",
369 "build",
370 "built",
371 "stay",
372 "stayed",
373 "fall",
374 "fell",
375 "cut",
376 "cuts",
377 "reach",
378 "reached",
379 "kill",
380 "killed",
381 "remain",
382 "remained",
383 ];
384
385 stop_words.into_iter().map(|s| s.to_string()).collect()
386 }
387
388 pub fn index_document(&self, doc: &Document) -> Result<()> {
401 if let Some(Value::String(text)) = doc.data.get(&self.field) {
402 self.remove_document_internal(&doc._sid);
404
405 let terms = self.tokenize(text);
406 if terms.is_empty() {
407 return Ok(());
408 }
409
410 let doc_length = terms.len();
412 self.doc_lengths.insert(doc._sid.clone(), doc_length);
413 self.document_count.fetch_add(1, Ordering::Relaxed);
414 self.total_term_count
415 .fetch_add(doc_length as u64, Ordering::Relaxed);
416
417 let term_freq = self.calculate_term_frequencies(&terms);
419
420 for (term, freq) in term_freq {
422 self.index
423 .entry(term.clone())
424 .or_default()
425 .push((doc._sid.clone(), freq));
426 }
427
428 self.rebuild_fst();
430 }
431 Ok(())
432 }
433
434 pub fn search(&self, query: &str) -> Vec<(String, f32)> {
450 let query_terms = self.tokenize(query);
451 if query_terms.is_empty() {
452 return Vec::new();
453 }
454
455 let mut scores: HashMap<String, f32> = HashMap::new();
456 let total_docs = self.document_count.load(Ordering::Relaxed).max(1) as f32;
457 let avg_doc_len = self.total_term_count.load(Ordering::Relaxed) as f32 / total_docs;
458
459 for term in query_terms {
460 let Some(term_entry) = self.index.get(&term) else {
461 continue;
462 };
463
464 let doc_freq = term_entry.len() as f32;
466 let idf = if doc_freq >= total_docs {
468 0.1
470 } else {
471 ((total_docs - doc_freq + 0.5) / (doc_freq + 0.5))
472 .ln()
473 .max(0.1)
474 };
475
476 for (doc_id, tf) in term_entry.iter() {
478 let doc_len = self
479 .doc_lengths
480 .get(doc_id)
481 .map(|entry| *entry.value() as f32)
482 .unwrap_or(avg_doc_len);
483
484 let tf_numerator = tf * (self.k1 + 1.0);
486 let tf_denominator =
487 tf + self.k1 * (1.0 - self.b + self.b * (doc_len / avg_doc_len));
488 let tf_bm25 = tf_numerator / tf_denominator;
489
490 *scores.entry(doc_id.clone()).or_insert(0.0) += tf_bm25 * idf;
491 }
492 }
493
494 let mut results: Vec<_> = scores.into_iter().collect();
496 results.sort_by(|(_, score1), (_, score2)| score2.partial_cmp(score1).unwrap());
497 results
498 }
499
500 pub fn fuzzy_search(&self, query: &str, max_distance: usize) -> Vec<(String, f32)> {
518 let query_terms = self.tokenize(query);
519 if query_terms.is_empty() {
520 return Vec::new();
521 }
522
523 let mut scores: HashMap<String, f32> = HashMap::new();
524 let total_docs = self.document_count.load(Ordering::Relaxed).max(1) as f32;
525 let avg_doc_len = self.total_term_count.load(Ordering::Relaxed) as f32 / total_docs;
526
527 if let Ok(fst_guard) = self.term_fst.read()
529 && let Some(ref fst) = *fst_guard
530 {
531 for query_term in query_terms {
532 if let Ok(lev) = Levenshtein::new(&query_term, max_distance as u32) {
534 let mut stream = fst.search(lev).into_stream();
535
536 while let Some(term_bytes) = stream.next() {
538 let term = String::from_utf8_lossy(term_bytes);
539
540 if let Some(term_entry) = self.index.get(term.as_ref()) {
541 let distance = self.levenshtein_distance(&query_term, &term) as f32;
543 let distance_penalty = 1.0 / (1.0 + distance * 0.3); let doc_freq = term_entry.len() as f32;
547 let idf = if doc_freq >= total_docs {
548 0.1
550 } else {
551 ((total_docs - doc_freq + 0.5) / (doc_freq + 0.5))
552 .ln()
553 .max(0.1)
554 };
555
556 for (doc_id, tf) in term_entry.iter() {
557 let doc_len = self
558 .doc_lengths
559 .get(doc_id)
560 .map(|entry| *entry.value() as f32)
561 .unwrap_or(avg_doc_len);
562
563 let tf_numerator = tf * (self.k1 + 1.0);
564 let tf_denominator = tf
565 + self.k1 * (1.0 - self.b + self.b * (doc_len / avg_doc_len));
566 let tf_bm25 = tf_numerator / tf_denominator;
567
568 let score = tf_bm25 * idf * distance_penalty;
569 *scores.entry(doc_id.clone()).or_insert(0.0) += score;
570 }
571 }
572 }
573 }
574 }
575 }
576
577 let mut results: Vec<_> = scores.into_iter().collect();
579 results.sort_by(|(_, score1), (_, score2)| score2.partial_cmp(score1).unwrap());
580 results
581 }
582
583 pub fn remove_document(&self, doc_id: &str) {
596 self.remove_document_internal(doc_id);
597 self.rebuild_fst();
598 }
599
600 fn remove_document_internal(&self, doc_id: &str) {
602 if let Some((_, doc_length)) = self.doc_lengths.remove(doc_id) {
604 self.document_count.fetch_sub(1, Ordering::Relaxed);
605 self.total_term_count
606 .fetch_sub(doc_length as u64, Ordering::Relaxed);
607 }
608
609 let mut empty_terms = Vec::new();
613
614 for mut entry in self.index.iter_mut() {
615 let term = entry.key().clone();
616 let doc_entries = entry.value_mut();
617
618 doc_entries.retain(|(id, _)| id != doc_id);
619
620 if doc_entries.is_empty() {
621 empty_terms.push(term);
622 }
623 }
624
625 for term in empty_terms {
627 self.index.remove(&term);
628 }
629 }
630
631 fn rebuild_fst(&self) {
633 let terms: Vec<String> = self.index.iter().map(|entry| entry.key().clone()).collect();
634
635 if !terms.is_empty() {
636 let mut sorted_terms = terms;
637 sorted_terms.sort();
638
639 if let Ok(fst) = fst::Set::from_iter(sorted_terms)
640 && let Ok(mut fst_guard) = self.term_fst.write()
641 {
642 *fst_guard = Some(fst);
643 }
644 }
645 }
646
647 pub fn highlight_matches(&self, text: &str, query: &str) -> String {
665 let query_terms: Vec<String> = self
666 .tokenize(query)
667 .into_iter()
668 .map(|t| regex::escape(&t))
669 .collect();
670
671 if query_terms.is_empty() {
672 return text.to_string();
673 }
674
675 let pattern = format!("({})", query_terms.join("|"));
677
678 match regex::RegexBuilder::new(&pattern)
679 .case_insensitive(true)
680 .build()
681 {
682 Ok(re) => re.replace_all(text, "<mark>$0</mark>").to_string(),
683 Err(_) => text.to_string(),
684 }
685 }
686
687 fn tokenize(&self, text: &str) -> Vec<String> {
689 let mut tokens: Vec<String> = text
690 .unicode_words()
691 .map(|word| word.to_lowercase())
692 .filter(|word| word.len() > 1) .collect();
694
695 if self.enable_stop_words {
697 tokens.retain(|token| !self.stop_words.contains(token));
698 }
699
700 tokens
701 }
702
703 fn calculate_term_frequencies(&self, terms: &[String]) -> HashMap<String, f32> {
705 let mut freq = HashMap::new();
706 for term in terms {
707 *freq.entry(term.clone()).or_insert(0.0) += 1.0;
708 }
709 freq
710 }
711
712 fn levenshtein_distance(&self, s1: &str, s2: &str) -> usize {
714 let s1_chars: Vec<char> = s1.chars().collect();
715 let s2_chars: Vec<char> = s2.chars().collect();
716 let s1_len = s1_chars.len();
717 let s2_len = s2_chars.len();
718
719 if s1_len == 0 {
720 return s2_len;
721 }
722 if s2_len == 0 {
723 return s1_len;
724 }
725
726 let mut prev_row = (0..=s2_len).collect::<Vec<_>>();
727 let mut curr_row = vec![0; s2_len + 1];
728
729 for i in 1..=s1_len {
730 curr_row[0] = i;
731
732 for j in 1..=s2_len {
733 let cost = if s1_chars[i - 1] == s2_chars[j - 1] {
734 0
735 } else {
736 1
737 };
738 curr_row[j] = std::cmp::min(
739 curr_row[j - 1] + 1, std::cmp::min(
741 prev_row[j] + 1, prev_row[j - 1] + cost, ),
744 );
745 }
746
747 std::mem::swap(&mut prev_row, &mut curr_row);
748 }
749
750 prev_row[s2_len]
751 }
752
753 pub fn set_stop_words_filter(&mut self, enable: bool) {
755 self.enable_stop_words = enable;
756 }
757
758 pub fn add_stopwords(&mut self, words: &[&str]) {
760 for word in words {
761 self.stop_words.insert(word.to_lowercase());
762 }
763 }
764
765 pub fn clear_stopwords(&mut self) {
767 self.stop_words.clear();
768 }
769
770 pub fn get_stopwords(&self) -> Vec<String> {
772 self.stop_words.iter().cloned().collect()
773 }
774
775 pub fn get_stats(&self) -> IndexStats {
777 IndexStats {
778 total_documents: self.document_count.load(Ordering::Relaxed),
779 total_terms: self.index.len(),
780 total_term_instances: self.total_term_count.load(Ordering::Relaxed),
781 average_document_length: if self.document_count.load(Ordering::Relaxed) > 0 {
782 self.total_term_count.load(Ordering::Relaxed) as f32
783 / self.document_count.load(Ordering::Relaxed) as f32
784 } else {
785 0.0
786 },
787 }
788 }
789
790 pub fn get_fuzzy_matches(&self, query: &str, max_distance: u32) -> Vec<(String, f32)> {
792 self.fuzzy_search(query, max_distance as usize)
793 }
794}
795
/// Point-in-time summary of a [`FullTextIndex`], as produced by its `get_stats`.
///
/// `Default` is the all-zero summary of an empty index. `PartialEq` lets callers
/// and tests compare snapshots directly.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct IndexStats {
    /// Number of documents currently indexed (those whose field yielded tokens).
    pub total_documents: usize,
    /// Number of distinct terms in the inverted index.
    pub total_terms: usize,
    /// Total number of tokens across all indexed documents.
    pub total_term_instances: u64,
    /// `total_term_instances / total_documents`, or `0.0` for an empty index.
    pub average_document_length: f32,
}
804
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a document with id `id` whose `content` field holds `text`.
    fn content_doc(id: &str, text: &str) -> Document {
        let mut data = HashMap::new();
        data.insert("content".to_string(), Value::String(text.to_string()));
        Document {
            _sid: id.to_string(),
            data,
        }
    }

    /// Looks up the score recorded for `id` in a result list (panics if absent).
    fn score_of(results: &[(String, f32)], id: &str) -> f32 {
        results
            .iter()
            .find(|(doc_id, _)| doc_id == id)
            .map(|(_, score)| *score)
            .unwrap()
    }

    #[test]
    fn test_enhanced_search() -> Result<()> {
        let index = FullTextIndex::new("test_collection", "content");
        index.index_document(&content_doc("1", "wireless bluetooth headphones"))?;
        index.index_document(&content_doc(
            "2",
            "wireless bluetooth headphones with excellent sound quality and noise cancellation technology for music lovers",
        ))?;

        let results = index.search("wireless");
        assert_eq!(results.len(), 2);

        let short_score = score_of(&results, "1");
        let long_score = score_of(&results, "2");
        assert!(
            short_score > long_score,
            "Shorter document should have higher BM25 score: {} vs {}",
            short_score,
            long_score
        );
        Ok(())
    }

    #[test]
    fn test_fuzzy_search() -> Result<()> {
        let index = FullTextIndex::new("test", "content");
        index.index_document(&content_doc("1", "wireless bluetooth technology"))?;

        for misspelled in ["wireles", "bluetoth"] {
            let results = index.fuzzy_search(misspelled, 2);
            assert!(!results.is_empty(), "Should find fuzzy matches");
        }
        Ok(())
    }

    #[test]
    fn test_stop_words() -> Result<()> {
        let index = FullTextIndex::new("test", "content");
        index.index_document(&content_doc("1", "the quick brown fox"))?;

        assert!(index.search("the").is_empty(), "Stop words should be filtered");
        assert!(!index.search("quick").is_empty(), "Non-stop words should be found");
        Ok(())
    }

    #[test]
    fn test_highlight_matches() -> Result<()> {
        let index = FullTextIndex::new("test", "content");

        let highlighted =
            index.highlight_matches("This is a wireless bluetooth device", "wireless bluetooth");

        assert!(highlighted.contains("<mark>wireless</mark>"));
        assert!(highlighted.contains("<mark>bluetooth</mark>"));
        Ok(())
    }

    #[test]
    fn test_document_removal() -> Result<()> {
        let index = FullTextIndex::new("test", "content");
        index.index_document(&content_doc("1", "wireless technology"))?;
        index.index_document(&content_doc("2", "bluetooth technology"))?;

        assert_eq!(index.search("technology").len(), 2);

        index.remove_document("1");

        let remaining = index.search("technology");
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].0, "2");
        Ok(())
    }

    #[test]
    fn test_levenshtein_distance() -> Result<()> {
        let index = FullTextIndex::new("test", "content");

        let cases = [
            ("kitten", "sitting", 3),
            ("saturday", "sunday", 3),
            ("", "abc", 3),
            ("abc", "", 3),
            ("same", "same", 0),
        ];
        for (left, right, expected) in cases {
            assert_eq!(index.levenshtein_distance(left, right), expected);
        }
        Ok(())
    }
}