1use std::collections::HashMap;
4
5use crate::backend::FtsBackend;
6use crate::bm25::bm25_score;
7use crate::index::FtsIndex;
8use crate::posting::{Posting, QueryMode, TextSearchResult};
9use crate::search::phrase;
10
11impl<B: FtsBackend> FtsIndex<B> {
12 pub fn search(
14 &self,
15 tid: u32,
16 collection: &str,
17 query: &str,
18 top_k: usize,
19 fuzzy_enabled: bool,
20 ) -> Result<Vec<TextSearchResult>, B::Error> {
21 self.search_with_mode(tid, collection, query, top_k, fuzzy_enabled, QueryMode::And)
22 }
23
24 pub fn search_with_mode(
26 &self,
27 tid: u32,
28 collection: &str,
29 query: &str,
30 top_k: usize,
31 fuzzy_enabled: bool,
32 mode: QueryMode,
33 ) -> Result<Vec<TextSearchResult>, B::Error> {
34 let query_tokens = self.analyze_for_collection(tid, collection, query)?;
35 if query_tokens.is_empty() {
36 return Ok(Vec::new());
37 }
38 let num_query_terms = query_tokens.len();
39
40 let raw_tokens = if fuzzy_enabled {
41 self.tokenize_raw_for_collection(tid, collection, query)?
42 } else {
43 Vec::new()
44 };
45
46 let (total_docs, avg_doc_len) = self.index_stats(tid, collection)?;
47 if total_docs == 0 {
48 return Ok(Vec::new());
49 }
50
51 let bmw_params = super::bmw::query::BmwParams {
52 query_tokens: &query_tokens,
53 raw_tokens: &raw_tokens,
54 fuzzy_enabled,
55 top_k: if mode == QueryMode::And && num_query_terms > 1 {
56 top_k.saturating_mul(3).max(20)
57 } else {
58 top_k
59 },
60 total_docs,
61 avg_doc_len,
62 bm25: &self.bm25_params,
63 };
64 if let Ok(Some(bmw_results)) =
65 super::bmw::query::bmw_search(self, tid, collection, &bmw_params)
66 {
67 if mode == QueryMode::Or || num_query_terms == 1 {
68 return Ok(bmw_results.into_iter().take(top_k).collect());
69 }
70
71 let and_results = self.filter_and_mode(
72 tid,
73 collection,
74 &query_tokens,
75 &bmw_results,
76 num_query_terms,
77 )?;
78
79 if !and_results.is_empty() {
80 return Ok(and_results.into_iter().take(top_k).collect());
81 }
82
83 let penalized: Vec<TextSearchResult> = bmw_results
84 .into_iter()
85 .map(|mut r| {
86 let matched =
87 self.count_term_matches(tid, collection, &query_tokens, &r.doc_id);
88 let coverage = matched as f32 / num_query_terms as f32;
89 r.score *= coverage;
90 r
91 })
92 .collect();
93 let mut sorted = penalized;
94 sorted.sort_by(|a, b| {
95 b.score
96 .partial_cmp(&a.score)
97 .unwrap_or(std::cmp::Ordering::Equal)
98 });
99 sorted.truncate(top_k);
100 return Ok(sorted);
101 }
102
103 let mut term_postings: Vec<(Vec<Posting>, bool)> = Vec::with_capacity(num_query_terms);
105 for (i, token) in query_tokens.iter().enumerate() {
106 let postings = self.backend.read_postings(tid, collection, token)?;
107 if !postings.is_empty() {
108 term_postings.push((postings, false));
109 } else if fuzzy_enabled {
110 let raw = raw_tokens
111 .get(i)
112 .map(String::as_str)
113 .unwrap_or(token.as_str());
114 let (fuzzy_posts, is_fuzzy) = self.fuzzy_lookup(tid, collection, raw)?;
115 term_postings.push((fuzzy_posts, is_fuzzy));
116 } else {
117 term_postings.push((Vec::new(), false));
118 }
119 }
120
121 let mut doc_scores: HashMap<String, (f32, bool, usize)> = HashMap::new();
122
123 for (token_idx, (postings, is_fuzzy)) in term_postings.iter().enumerate() {
124 if postings.is_empty() {
125 continue;
126 }
127 let df = postings.len() as u32;
128
129 for posting in postings {
130 let doc_len = self
131 .backend
132 .read_doc_length(tid, collection, &posting.doc_id)?
133 .unwrap_or(1);
134
135 let mut score = bm25_score(
136 posting.term_freq,
137 df,
138 doc_len,
139 total_docs,
140 avg_doc_len,
141 &self.bm25_params,
142 );
143
144 if *is_fuzzy {
145 score *= crate::fuzzy::fuzzy_discount(1);
146 }
147
148 let entry = doc_scores
149 .entry(posting.doc_id.clone())
150 .or_insert((0.0, false, 0));
151 entry.0 += score;
152 if *is_fuzzy {
153 entry.1 = true;
154 }
155 entry.2 += 1;
156 }
157 let _ = token_idx;
158 }
159
160 if num_query_terms >= 2 {
161 let doc_postings_map = phrase::collect_doc_postings(&query_tokens, &term_postings);
162 for (doc_id, token_postings) in &doc_postings_map {
163 if let Some(entry) = doc_scores.get_mut(doc_id.as_str()) {
164 let boost = phrase::phrase_boost(&query_tokens, token_postings);
165 entry.0 *= boost;
166 }
167 }
168 }
169
170 if mode == QueryMode::And && num_query_terms > 1 {
171 let and_results: HashMap<String, (f32, bool, usize)> = doc_scores
172 .iter()
173 .filter(|(_, (_, _, match_count))| *match_count >= num_query_terms)
174 .map(|(k, v)| (k.clone(), *v))
175 .collect();
176
177 if !and_results.is_empty() {
178 return Ok(Self::to_sorted_results(and_results, top_k));
179 }
180
181 for (score, _, match_count) in doc_scores.values_mut() {
182 let coverage = *match_count as f32 / num_query_terms as f32;
183 *score *= coverage;
184 }
185 }
186
187 Ok(Self::to_sorted_results(doc_scores, top_k))
188 }
189
190 fn filter_and_mode(
191 &self,
192 tid: u32,
193 collection: &str,
194 query_tokens: &[String],
195 candidates: &[TextSearchResult],
196 num_terms: usize,
197 ) -> Result<Vec<TextSearchResult>, B::Error> {
198 let doc_map = self.load_doc_id_map(tid, collection)?;
199 let term_blocks = crate::lsm::query::collect_merged_term_blocks(
200 &self.backend,
201 tid,
202 collection,
203 self.memtable(),
204 query_tokens,
205 )?;
206
207 let mut results = Vec::new();
208 for candidate in candidates {
209 let int_id = doc_map.to_u32(&candidate.doc_id);
210 let matched = term_blocks
211 .iter()
212 .filter(|tb| {
213 int_id.is_some_and(|id| tb.blocks.iter().any(|b| b.doc_ids.contains(&id)))
214 })
215 .count();
216 if matched >= num_terms {
217 results.push(candidate.clone());
218 }
219 }
220 Ok(results)
221 }
222
223 fn count_term_matches(
224 &self,
225 tid: u32,
226 collection: &str,
227 query_tokens: &[String],
228 doc_id: &str,
229 ) -> usize {
230 let doc_map = match self.load_doc_id_map(tid, collection) {
231 Ok(m) => m,
232 Err(_) => return 0,
233 };
234 let Some(int_id) = doc_map.to_u32(doc_id) else {
235 return 0;
236 };
237 let term_blocks = match crate::lsm::query::collect_merged_term_blocks(
238 &self.backend,
239 tid,
240 collection,
241 self.memtable(),
242 query_tokens,
243 ) {
244 Ok(tb) => tb,
245 Err(_) => return 0,
246 };
247 term_blocks
248 .iter()
249 .filter(|tb| tb.blocks.iter().any(|b| b.doc_ids.contains(&int_id)))
250 .count()
251 }
252
253 fn to_sorted_results(
254 doc_scores: HashMap<String, (f32, bool, usize)>,
255 top_k: usize,
256 ) -> Vec<TextSearchResult> {
257 let mut results: Vec<TextSearchResult> = doc_scores
258 .into_iter()
259 .map(|(doc_id, (score, fuzzy_flag, _))| TextSearchResult {
260 doc_id,
261 score,
262 fuzzy: fuzzy_flag,
263 })
264 .collect();
265 results.sort_by(|a, b| {
266 b.score
267 .partial_cmp(&a.score)
268 .unwrap_or(std::cmp::Ordering::Equal)
269 });
270 results.truncate(top_k);
271 results
272 }
273}
274
#[cfg(test)]
mod tests {
    use crate::backend::memory::MemoryBackend;
    use crate::index::FtsIndex;
    use crate::posting::QueryMode;

    /// Tenant id used by every test that does not exercise tenant isolation.
    const T: u32 = 1;

    /// Builds an in-memory index holding `docs` (`(doc_id, text)` pairs) in the
    /// `"docs"` collection of tenant `T`, indexed in the given order.
    fn index_with(docs: &[(&'static str, &'static str)]) -> FtsIndex<MemoryBackend> {
        let idx = FtsIndex::new(MemoryBackend::new());
        for &(doc_id, text) in docs {
            idx.index_document(T, "docs", doc_id, text).unwrap();
        }
        idx
    }

    /// Three-document fixture shared by several tests.
    fn make_index() -> FtsIndex<MemoryBackend> {
        index_with(&[
            ("d1", "The quick brown fox jumps over the lazy dog"),
            ("d2", "A fast brown dog runs across the field"),
            ("d3", "Rust programming language for systems"),
        ])
    }

    #[test]
    fn basic_search() {
        let idx = make_index();
        let hits = idx.search(T, "docs", "brown fox", 10, false).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn search_with_stemming() {
        let idx = index_with(&[
            ("d1", "running distributed databases"),
            ("d2", "the cat sat on a mat"),
        ]);

        let hits = idx
            .search(T, "docs", "database distribution", 10, false)
            .unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn or_mode() {
        let idx = make_index();
        let hits = idx
            .search_with_mode(T, "docs", "brown fox", 10, false, QueryMode::Or)
            .unwrap();
        assert!(hits.len() >= 2);
    }

    #[test]
    fn and_mode_filters() {
        let idx = index_with(&[
            ("d1", "Rust programming language"),
            ("d2", "Python programming language"),
        ]);

        let hits = idx
            .search_with_mode(T, "docs", "rust programming", 10, false, QueryMode::And)
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn and_fallback_to_or() {
        let idx = index_with(&[
            ("d1", "rust programming language"),
            ("d2", "python programming language"),
        ]);

        // No document contains both terms, so both partial matches come back.
        let hits = idx.search(T, "docs", "rust python", 10, false).unwrap();
        assert_eq!(hits.len(), 2);
        for hit in &hits {
            assert!(hit.score > 0.0);
        }
    }

    #[test]
    fn and_no_fallback_when_results_exist() {
        let idx = index_with(&[
            ("d1", "rust programming language"),
            ("d2", "python programming language"),
        ]);

        let hits = idx
            .search(T, "docs", "rust programming", 10, false)
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn empty_query() {
        let idx = make_index();
        let hits = idx.search(T, "docs", "the a is", 10, false).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn collections_isolated() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "col_a", "d1", "alpha bravo charlie")
            .unwrap();
        idx.index_document(T, "col_b", "d1", "delta echo foxtrot")
            .unwrap();

        let in_a = idx.search(T, "col_a", "alpha", 10, false).unwrap();
        assert_eq!(in_a.len(), 1);

        let in_b = idx.search(T, "col_b", "alpha", 10, false).unwrap();
        assert!(in_b.is_empty());
    }

    #[test]
    fn fuzzy_search() {
        let idx = index_with(&[("d1", "distributed database systems")]);

        let hits = idx.search(T, "docs", "databse", 10, true).unwrap();
        assert!(!hits.is_empty());
        assert!(hits[0].fuzzy);
    }

    #[test]
    fn phrase_boost_consecutive() {
        let idx = index_with(&[
            ("d1", "the quick brown fox jumped"),
            ("d2", "a brown dog chased a fox"),
        ]);

        let hits = idx
            .search_with_mode(T, "docs", "brown fox", 10, false, QueryMode::Or)
            .unwrap();
        assert!(hits.len() >= 2);
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn phrase_boost_no_effect_single_term() {
        let idx = index_with(&[("d1", "hello world")]);

        let hits = idx.search(T, "docs", "hello", 10, false).unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn tenants_isolated() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(1, "docs", "d1", "alpha bravo").unwrap();
        idx.index_document(2, "docs", "d1", "charlie delta")
            .unwrap();

        let tenant_one = idx.search(1, "docs", "alpha", 10, false).unwrap();
        let tenant_two = idx.search(2, "docs", "alpha", 10, false).unwrap();
        assert_eq!(tenant_one.len(), 1);
        assert!(tenant_two.is_empty());
    }
}