1use std::collections::HashMap;
25
26use nodedb_types::{Surrogate, SurrogateBitmap};
27
28use crate::backend::FtsBackend;
29use crate::bm25::bm25_score;
30use crate::index::FtsIndex;
31use crate::index::error::FtsIndexError;
32use crate::posting::{Posting, QueryMode, TextSearchResult};
33use crate::search::phrase;
34use crate::search::query_parser::parse_query;
35
36impl<B: FtsBackend> FtsIndex<B> {
37 pub fn search(
43 &self,
44 tid: u64,
45 collection: &str,
46 query: &str,
47 top_k: usize,
48 fuzzy_enabled: bool,
49 prefilter: Option<&SurrogateBitmap>,
50 ) -> Result<Vec<TextSearchResult>, FtsIndexError<B::Error>> {
51 self.search_with_mode(
52 tid,
53 collection,
54 query,
55 top_k,
56 fuzzy_enabled,
57 QueryMode::And,
58 prefilter,
59 )
60 }
61
62 #[allow(clippy::too_many_arguments)]
66 pub fn search_with_mode(
67 &self,
68 tid: u64,
69 collection: &str,
70 query: &str,
71 top_k: usize,
72 fuzzy_enabled: bool,
73 mode: QueryMode,
74 prefilter: Option<&SurrogateBitmap>,
75 ) -> Result<Vec<TextSearchResult>, FtsIndexError<B::Error>> {
76 let parsed = parse_query(query)?;
78
79 let positive_raw = parsed.positive.join(" ");
84 let negative_raw_terms = parsed.negative;
85
86 let base_tokens = self
87 .analyze_for_collection(tid, collection, &positive_raw)
88 .map_err(FtsIndexError::backend)?;
89 if base_tokens.is_empty() {
90 return Ok(Vec::new());
91 }
92
93 let base_token_count = base_tokens.len();
94 let query_tokens = self
95 .expand_query_with_synonyms(tid, base_tokens)
96 .map_err(FtsIndexError::backend)?;
97 let num_query_terms = query_tokens.len();
98 let and_threshold = base_token_count;
99
100 let raw_tokens = if fuzzy_enabled {
101 self.tokenize_raw_for_collection(tid, collection, &positive_raw)
102 .map_err(FtsIndexError::backend)?
103 } else {
104 Vec::new()
105 };
106
107 let (total_docs, avg_doc_len) = self
108 .index_stats(tid, collection)
109 .map_err(FtsIndexError::backend)?;
110 if total_docs == 0 {
111 return Ok(Vec::new());
112 }
113
114 let negative_set = self.build_negative_set(tid, collection, &negative_raw_terms)?;
118
119 let bmw_params = super::bmw::query::BmwParams {
120 query_tokens: &query_tokens,
121 raw_tokens: &raw_tokens,
122 fuzzy_enabled,
123 top_k: if mode == QueryMode::And && and_threshold > 1 {
124 top_k.saturating_mul(3).max(20)
125 } else {
126 top_k
127 },
128 total_docs,
129 avg_doc_len,
130 bm25: &self.bm25_params,
131 prefilter,
132 };
133 if let Ok(Some(bmw_results)) =
134 super::bmw::query::bmw_search(self, tid, collection, &bmw_params)
135 {
136 if mode == QueryMode::Or || and_threshold == 1 {
137 let mut results: Vec<TextSearchResult> = bmw_results
138 .into_iter()
139 .filter(|r| !negative_set.contains(&r.doc_id))
140 .take(top_k)
141 .collect();
142 results.truncate(top_k);
143 return Ok(results);
144 }
145
146 let and_results = self
147 .filter_and_mode(tid, collection, &query_tokens, &bmw_results, and_threshold)
148 .map_err(FtsIndexError::backend)?;
149
150 if !and_results.is_empty() {
151 let filtered: Vec<TextSearchResult> = and_results
152 .into_iter()
153 .filter(|r| !negative_set.contains(&r.doc_id))
154 .take(top_k)
155 .collect();
156 return Ok(filtered);
157 }
158
159 let penalized: Vec<TextSearchResult> = bmw_results
160 .into_iter()
161 .filter(|r| !negative_set.contains(&r.doc_id))
162 .map(|mut r| {
163 let matched = self.count_term_matches(tid, collection, &query_tokens, r.doc_id);
164 let coverage = matched as f32 / and_threshold as f32;
165 r.score *= coverage;
166 r
167 })
168 .collect();
169 let mut sorted = penalized;
170 sorted.sort_by(|a, b| {
171 b.score
172 .partial_cmp(&a.score)
173 .unwrap_or(std::cmp::Ordering::Equal)
174 });
175 sorted.truncate(top_k);
176 return Ok(sorted);
177 }
178
179 let _term_postings_guard = self.governor.as_ref().and_then(|gov| {
181 let bytes = num_query_terms
182 * (std::mem::size_of::<Vec<Posting>>() + std::mem::size_of::<bool>());
183 gov.reserve(nodedb_mem::EngineId::Fts, bytes).ok()
184 });
185 let mut term_postings: Vec<(Vec<Posting>, bool)> = Vec::with_capacity(num_query_terms);
186 for (i, token) in query_tokens.iter().enumerate() {
187 let postings = self
188 .backend
189 .read_postings(tid, collection, token)
190 .map_err(FtsIndexError::backend)?;
191 if !postings.is_empty() {
192 term_postings.push((postings, false));
193 } else if fuzzy_enabled {
194 let raw = raw_tokens
195 .get(i)
196 .map(String::as_str)
197 .unwrap_or(token.as_str());
198 let (fuzzy_posts, is_fuzzy) = self
199 .fuzzy_lookup(tid, collection, raw)
200 .map_err(FtsIndexError::backend)?;
201 term_postings.push((fuzzy_posts, is_fuzzy));
202 } else {
203 term_postings.push((Vec::new(), false));
204 }
205 }
206
207 let mut doc_scores: HashMap<Surrogate, (f32, bool, usize)> = HashMap::new();
208
209 for (token_idx, (postings, is_fuzzy)) in term_postings.iter().enumerate() {
210 if postings.is_empty() {
211 continue;
212 }
213 let df = postings.len() as u32;
214
215 for posting in postings {
216 if let Some(bm) = prefilter
218 && !bm.contains(posting.doc_id)
219 {
220 continue;
221 }
222
223 let doc_len = self
224 .backend
225 .read_doc_length(tid, collection, posting.doc_id)
226 .map_err(FtsIndexError::backend)?
227 .unwrap_or(1);
228
229 let mut score = bm25_score(
230 posting.term_freq,
231 df,
232 doc_len,
233 total_docs,
234 avg_doc_len,
235 &self.bm25_params,
236 );
237
238 if *is_fuzzy {
239 score *= crate::fuzzy::fuzzy_discount(1);
240 }
241
242 let entry = doc_scores.entry(posting.doc_id).or_insert((0.0, false, 0));
243 entry.0 += score;
244 if *is_fuzzy {
245 entry.1 = true;
246 }
247 entry.2 += 1;
248 }
249 let _ = token_idx;
250 }
251
252 if num_query_terms >= 2 {
253 let doc_postings_map = phrase::collect_doc_postings(&query_tokens, &term_postings);
254 for (doc_id, token_postings) in &doc_postings_map {
255 if let Some(entry) = doc_scores.get_mut(doc_id) {
256 let boost = phrase::phrase_boost(&query_tokens, token_postings);
257 entry.0 *= boost;
258 }
259 }
260 }
261
262 if mode == QueryMode::And && and_threshold > 1 {
263 let and_results: HashMap<Surrogate, (f32, bool, usize)> = doc_scores
264 .iter()
265 .filter(|(_, (_, _, match_count))| *match_count >= and_threshold)
266 .map(|(k, v)| (*k, *v))
267 .collect();
268
269 if !and_results.is_empty() {
270 let filtered = and_results
271 .into_iter()
272 .filter(|(doc_id, _)| !negative_set.contains(doc_id))
273 .collect();
274 return Ok(Self::to_sorted_results(filtered, top_k));
275 }
276
277 for (score, _, match_count) in doc_scores.values_mut() {
278 let coverage = *match_count as f32 / and_threshold as f32;
279 *score *= coverage;
280 }
281 }
282
283 let filtered: HashMap<Surrogate, (f32, bool, usize)> = doc_scores
285 .into_iter()
286 .filter(|(doc_id, _)| !negative_set.contains(doc_id))
287 .collect();
288
289 Ok(Self::to_sorted_results(filtered, top_k))
290 }
291
292 fn build_negative_set(
297 &self,
298 tid: u64,
299 collection: &str,
300 raw_negative_terms: &[String],
301 ) -> Result<std::collections::HashSet<Surrogate>, FtsIndexError<B::Error>> {
302 if raw_negative_terms.is_empty() {
303 return Ok(std::collections::HashSet::new());
304 }
305
306 let neg_raw = raw_negative_terms.join(" ");
308 let neg_base_tokens = self
309 .analyze_for_collection(tid, collection, &neg_raw)
310 .map_err(FtsIndexError::backend)?;
311
312 if neg_base_tokens.is_empty() {
313 return Ok(std::collections::HashSet::new());
314 }
315
316 let neg_tokens = self
318 .expand_query_with_synonyms(tid, neg_base_tokens)
319 .map_err(FtsIndexError::backend)?;
320
321 let mut excluded: std::collections::HashSet<Surrogate> = std::collections::HashSet::new();
322
323 let term_blocks = crate::lsm::query::collect_merged_term_blocks(
325 &self.backend,
326 tid,
327 collection,
328 self.memtable(),
329 &neg_tokens,
330 self.governor.as_ref(),
331 )
332 .map_err(FtsIndexError::backend)?;
333
334 for tb in &term_blocks {
335 for block in &tb.blocks {
336 for doc_id in &block.doc_ids {
337 excluded.insert(*doc_id);
338 }
339 }
340 }
341
342 for token in &neg_tokens {
344 let postings = self
345 .backend
346 .read_postings(tid, collection, token)
347 .map_err(FtsIndexError::backend)?;
348 for posting in postings {
349 excluded.insert(posting.doc_id);
350 }
351 }
352
353 Ok(excluded)
354 }
355
356 fn filter_and_mode(
357 &self,
358 tid: u64,
359 collection: &str,
360 query_tokens: &[String],
361 candidates: &[TextSearchResult],
362 num_terms: usize,
363 ) -> Result<Vec<TextSearchResult>, B::Error> {
364 let term_blocks = crate::lsm::query::collect_merged_term_blocks(
365 &self.backend,
366 tid,
367 collection,
368 self.memtable(),
369 query_tokens,
370 self.governor.as_ref(),
371 )?;
372
373 let mut results = Vec::new();
374 for candidate in candidates {
375 let surrogate = candidate.doc_id;
376 let matched = term_blocks
377 .iter()
378 .filter(|tb| tb.blocks.iter().any(|b| b.doc_ids.contains(&surrogate)))
379 .count();
380 if matched >= num_terms {
381 results.push(candidate.clone());
382 }
383 }
384 Ok(results)
385 }
386
387 fn count_term_matches(
388 &self,
389 tid: u64,
390 collection: &str,
391 query_tokens: &[String],
392 doc_id: Surrogate,
393 ) -> usize {
394 let term_blocks = match crate::lsm::query::collect_merged_term_blocks(
395 &self.backend,
396 tid,
397 collection,
398 self.memtable(),
399 query_tokens,
400 self.governor.as_ref(),
401 ) {
402 Ok(tb) => tb,
403 Err(_) => return 0,
404 };
405 term_blocks
406 .iter()
407 .filter(|tb| tb.blocks.iter().any(|b| b.doc_ids.contains(&doc_id)))
408 .count()
409 }
410
411 fn to_sorted_results(
412 doc_scores: HashMap<Surrogate, (f32, bool, usize)>,
413 top_k: usize,
414 ) -> Vec<TextSearchResult> {
415 let mut results: Vec<TextSearchResult> = doc_scores
416 .into_iter()
417 .map(|(doc_id, (score, fuzzy_flag, _))| TextSearchResult {
418 doc_id,
419 score,
420 fuzzy: fuzzy_flag,
421 })
422 .collect();
423 results.sort_by(|a, b| {
424 b.score
425 .partial_cmp(&a.score)
426 .unwrap_or(std::cmp::Ordering::Equal)
427 });
428 results.truncate(top_k);
429 results
430 }
431}
432
#[cfg(test)]
mod tests {
    //! Behavioral tests for `FtsIndex` search: ranking, query modes, NOT
    //! negation, prefilter bitmaps, fuzzy matching, phrase boosting, and
    //! collection/tenant isolation. All tests run on the in-memory backend.
    use nodedb_types::{Surrogate, SurrogateBitmap};

    use crate::backend::memory::MemoryBackend;
    use crate::index::FtsIndex;
    use crate::index::error::FtsIndexError;
    use crate::posting::QueryMode;
    use crate::search::query_parser::InvalidQuery;

    // Fixed tenant id and document surrogates shared across tests.
    const T: u64 = 1;
    const D1: Surrogate = Surrogate(1);
    const D2: Surrogate = Surrogate(2);
    const D3: Surrogate = Surrogate(3);

    // Builds an in-memory index with three small documents in "docs".
    fn make_index() -> FtsIndex<MemoryBackend> {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "The quick brown fox jumps over the lazy dog")
            .unwrap();
        idx.index_document(T, "docs", D2, "A fast brown dog runs across the field")
            .unwrap();
        idx.index_document(T, "docs", D3, "Rust programming language for systems")
            .unwrap();
        idx
    }

    // A two-term query ranks the doc containing both terms first.
    #[test]
    fn basic_search() {
        let idx = make_index();
        let results = idx.search(T, "docs", "brown fox", 10, false, None).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].doc_id, D1);
    }

    // Query terms match morphological variants in the indexed text
    // ("database" / "databases", "distribution" / "distributed").
    #[test]
    fn search_with_stemming() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "running distributed databases")
            .unwrap();
        idx.index_document(T, "docs", D2, "the cat sat on a mat")
            .unwrap();

        let results = idx
            .search(T, "docs", "database distribution", 10, false, None)
            .unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].doc_id, D1);
    }

    // OR mode returns docs matching either term, not just both.
    #[test]
    fn or_mode() {
        let idx = make_index();
        let results = idx
            .search_with_mode(T, "docs", "brown fox", 10, false, QueryMode::Or, None)
            .unwrap();
        assert!(results.len() >= 2);
    }

    // AND mode drops docs that match only a subset of the query terms.
    #[test]
    fn and_mode_filters() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "Rust programming language")
            .unwrap();
        idx.index_document(T, "docs", D2, "Python programming language")
            .unwrap();

        let results = idx
            .search_with_mode(
                T,
                "docs",
                "rust programming",
                10,
                false,
                QueryMode::And,
                None,
            )
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, D1);
    }

    // When no doc matches ALL terms, AND falls back to coverage-penalized
    // OR-style results — both partial matches come back with positive scores.
    #[test]
    fn and_fallback_to_or() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "rust programming language")
            .unwrap();
        idx.index_document(T, "docs", D2, "python programming language")
            .unwrap();

        let results = idx
            .search(T, "docs", "rust python", 10, false, None)
            .unwrap();
        assert_eq!(results.len(), 2);
        for r in &results {
            assert!(r.score > 0.0);
        }
    }

    // The fallback must NOT trigger when a full AND match exists.
    #[test]
    fn and_no_fallback_when_results_exist() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "rust programming language")
            .unwrap();
        idx.index_document(T, "docs", D2, "python programming language")
            .unwrap();

        let results = idx
            .search(T, "docs", "rust programming", 10, false, None)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, D1);
    }

    // A query made only of stopwords analyzes to nothing and returns empty.
    #[test]
    fn empty_query() {
        let idx = make_index();
        let results = idx.search(T, "docs", "the a is", 10, false, None).unwrap();
        assert!(results.is_empty());
    }

    // Searching one collection must never see another collection's docs.
    #[test]
    fn collections_isolated() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "col_a", D1, "alpha bravo charlie")
            .unwrap();
        idx.index_document(T, "col_b", D1, "delta echo foxtrot")
            .unwrap();

        assert_eq!(
            idx.search(T, "col_a", "alpha", 10, false, None)
                .unwrap()
                .len(),
            1
        );
        assert!(
            idx.search(T, "col_b", "alpha", 10, false, None)
                .unwrap()
                .is_empty()
        );
    }

    // A misspelled term ("databse") still matches via fuzzy lookup, and the
    // result carries the fuzzy flag.
    #[test]
    fn fuzzy_search() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "distributed database systems")
            .unwrap();

        let results = idx.search(T, "docs", "databse", 10, true, None).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].fuzzy);
    }

    // Adjacent query terms ("brown fox") rank the doc where they are
    // consecutive above the doc where they are scattered.
    #[test]
    fn phrase_boost_consecutive() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "the quick brown fox jumped")
            .unwrap();
        idx.index_document(T, "docs", D2, "a brown dog chased a fox")
            .unwrap();

        let results = idx
            .search_with_mode(T, "docs", "brown fox", 10, false, QueryMode::Or, None)
            .unwrap();
        assert!(results.len() >= 2);
        assert_eq!(results[0].doc_id, D1);
    }

    // Single-term queries skip the phrase-boost stage entirely.
    #[test]
    fn phrase_boost_no_effect_single_term() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "hello world").unwrap();

        let results = idx.search(T, "docs", "hello", 10, false, None).unwrap();
        assert_eq!(results.len(), 1);
    }

    // Documents indexed under one tenant id are invisible to another tenant.
    #[test]
    fn tenants_isolated() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(1, "docs", D1, "alpha bravo").unwrap();
        idx.index_document(2, "docs", D1, "charlie delta").unwrap();

        let r1 = idx.search(1, "docs", "alpha", 10, false, None).unwrap();
        let r2 = idx.search(2, "docs", "alpha", 10, false, None).unwrap();
        assert_eq!(r1.len(), 1);
        assert!(r2.is_empty());
    }

    // Prefilter bitmap: only members can appear in results, even when
    // non-members would score higher; an empty bitmap yields no results.
    #[test]
    fn prefilter_excludes_non_member_surrogates() {
        let idx = FtsIndex::new(MemoryBackend::new());

        idx.index_document(T, "docs", D1, "rust language system")
            .unwrap();
        idx.index_document(T, "docs", D2, "rust rust rust rust rust")
            .unwrap();
        idx.index_document(T, "docs", D3, "rust rust rust rust rust rust")
            .unwrap();

        let mut bm = SurrogateBitmap::new();
        bm.insert(D1);

        let results = idx.search(T, "docs", "rust", 10, false, Some(&bm)).unwrap();

        assert_eq!(results.len(), 1, "only D1 should be returned");
        assert_eq!(results[0].doc_id, D1);

        assert!(
            !results.iter().any(|r| r.doc_id == D2),
            "D2 must be excluded"
        );
        assert!(
            !results.iter().any(|r| r.doc_id == D3),
            "D3 must be excluded"
        );

        // Without a prefilter all three docs are returned, with the
        // higher-term-frequency docs ranked first.
        let all_results = idx.search(T, "docs", "rust", 10, false, None).unwrap();
        assert_eq!(all_results.len(), 3, "all docs returned without prefilter");
        assert!(
            all_results[0].doc_id == D2 || all_results[0].doc_id == D3,
            "D2 or D3 should lead without prefilter (higher tf)"
        );

        let empty_bm = SurrogateBitmap::new();
        let empty_results = idx
            .search(T, "docs", "rust", 10, false, Some(&empty_bm))
            .unwrap();
        assert!(empty_results.is_empty(), "empty prefilter → no results");

        let mut bm23 = SurrogateBitmap::new();
        bm23.insert(D2);
        bm23.insert(D3);
        let results23 = idx
            .search(T, "docs", "rust", 10, false, Some(&bm23))
            .unwrap();
        assert_eq!(results23.len(), 2);
        assert!(!results23.iter().any(|r| r.doc_id == D1));
    }

    // "rust NOT python" keeps rust-only docs and drops any doc with python.
    #[test]
    fn not_keyword_excludes_documents() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "rust python programming")
            .unwrap();
        idx.index_document(T, "docs", D2, "rust ruby programming")
            .unwrap();
        idx.index_document(T, "docs", D3, "python ruby programming")
            .unwrap();

        let results = idx
            .search(T, "docs", "rust NOT python", 10, false, None)
            .unwrap();
        assert!(
            results.iter().any(|r| r.doc_id == D2),
            "D2 (rust+ruby) must be in results"
        );
        assert!(
            !results.iter().any(|r| r.doc_id == D1),
            "D1 (rust+python) must be excluded"
        );
    }

    // The "-term" prefix syntax behaves identically to the NOT keyword.
    #[test]
    fn dash_prefix_excludes_documents() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "rust python programming")
            .unwrap();
        idx.index_document(T, "docs", D2, "rust ruby programming")
            .unwrap();
        idx.index_document(T, "docs", D3, "python ruby programming")
            .unwrap();

        let results = idx
            .search(T, "docs", "rust -python", 10, false, None)
            .unwrap();
        assert!(results.iter().any(|r| r.doc_id == D2));
        assert!(!results.iter().any(|r| r.doc_id == D1));
    }

    // Multiple NOT clauses combine: every negated term excludes its docs.
    #[test]
    fn multiple_not_excludes_all_negated() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "rust python programming")
            .unwrap();
        idx.index_document(T, "docs", D2, "rust ruby programming")
            .unwrap();
        idx.index_document(T, "docs", D3, "rust systems programming")
            .unwrap();

        let results = idx
            .search(T, "docs", "rust NOT python NOT ruby", 10, false, None)
            .unwrap();
        assert!(results.iter().any(|r| r.doc_id == D3));
        assert!(!results.iter().any(|r| r.doc_id == D1));
        assert!(!results.iter().any(|r| r.doc_id == D2));
    }

    // Negating a term absent from the index must be a no-op: the result set
    // equals the plain positive-only search.
    #[test]
    fn not_nonexistent_term_returns_all_positives() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "rust programming")
            .unwrap();
        idx.index_document(T, "docs", D2, "rust systems").unwrap();

        let results_plain = idx.search(T, "docs", "rust", 10, false, None).unwrap();
        let results_not = idx
            .search(T, "docs", "rust NOT nonexistentxyz", 10, false, None)
            .unwrap();

        let plain_ids: std::collections::HashSet<Surrogate> =
            results_plain.iter().map(|r| r.doc_id).collect();
        let not_ids: std::collections::HashSet<Surrogate> =
            results_not.iter().map(|r| r.doc_id).collect();
        assert_eq!(
            plain_ids, not_ids,
            "NOT with nonexistent term must not remove any docs"
        );
    }

    // A query with no positive terms is rejected at parse time.
    #[test]
    fn negative_only_returns_invalid_query_error() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "python programming")
            .unwrap();

        let err = idx
            .search(T, "docs", "NOT python", 10, false, None)
            .unwrap_err();
        assert!(
            matches!(err, FtsIndexError::InvalidQuery(InvalidQuery::NegativeOnly)),
            "expected InvalidQuery(NegativeOnly), got {err:?}"
        );
    }

    // Grouping parentheses after NOT are unsupported and rejected at parse time.
    #[test]
    fn parentheses_after_not_returns_invalid_query_error() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "docs", D1, "rust programming")
            .unwrap();

        let err = idx
            .search(T, "docs", "rust NOT (python OR ruby)", 10, false, None)
            .unwrap_err();
        assert!(
            matches!(
                err,
                FtsIndexError::InvalidQuery(InvalidQuery::ParenthesesNotSupported)
            ),
            "expected InvalidQuery(ParenthesesNotSupported), got {err:?}"
        );
    }
}