1use std::collections::HashMap;
4
5use crate::backend::FtsBackend;
6use crate::bm25::bm25_score;
7use crate::index::FtsIndex;
8use crate::posting::{Posting, QueryMode, TextSearchResult};
9use crate::search::phrase;
10
11impl<B: FtsBackend> FtsIndex<B> {
12 pub fn search(
18 &self,
19 collection: &str,
20 query: &str,
21 top_k: usize,
22 fuzzy_enabled: bool,
23 ) -> Result<Vec<TextSearchResult>, B::Error> {
24 self.search_with_mode(collection, query, top_k, fuzzy_enabled, QueryMode::And)
25 }
26
27 pub fn search_with_mode(
35 &self,
36 collection: &str,
37 query: &str,
38 top_k: usize,
39 fuzzy_enabled: bool,
40 mode: QueryMode,
41 ) -> Result<Vec<TextSearchResult>, B::Error> {
42 let query_tokens = self.analyze_for_collection(collection, query)?;
43 if query_tokens.is_empty() {
44 return Ok(Vec::new());
45 }
46 let num_query_terms = query_tokens.len();
47
48 let raw_tokens = if fuzzy_enabled {
51 self.tokenize_raw_for_collection(collection, query)?
52 } else {
53 Vec::new()
54 };
55
56 let (total_docs, avg_doc_len) = self.index_stats(collection)?;
57 if total_docs == 0 {
58 return Ok(Vec::new());
59 }
60
61 let bmw_params = super::bmw::query::BmwParams {
63 query_tokens: &query_tokens,
64 raw_tokens: &raw_tokens,
65 fuzzy_enabled,
66 top_k: if mode == QueryMode::And && num_query_terms > 1 {
67 top_k.saturating_mul(3).max(20)
68 } else {
69 top_k
70 },
71 total_docs,
72 avg_doc_len,
73 bm25: &self.bm25_params,
74 };
75 if let Ok(Some(bmw_results)) = super::bmw::query::bmw_search(self, collection, &bmw_params)
76 {
77 if mode == QueryMode::Or || num_query_terms == 1 {
78 return Ok(bmw_results.into_iter().take(top_k).collect());
79 }
80
81 let and_results =
84 self.filter_and_mode(collection, &query_tokens, &bmw_results, num_query_terms)?;
85
86 if !and_results.is_empty() {
87 return Ok(and_results.into_iter().take(top_k).collect());
88 }
89
90 let penalized: Vec<TextSearchResult> = bmw_results
92 .into_iter()
93 .map(|mut r| {
94 let matched = self.count_term_matches(collection, &query_tokens, &r.doc_id);
95 let coverage = matched as f32 / num_query_terms as f32;
96 r.score *= coverage;
97 r
98 })
99 .collect();
100 let mut sorted = penalized;
101 sorted.sort_by(|a, b| {
102 b.score
103 .partial_cmp(&a.score)
104 .unwrap_or(std::cmp::Ordering::Equal)
105 });
106 sorted.truncate(top_k);
107 return Ok(sorted);
108 }
109
110 let mut term_postings: Vec<(Vec<Posting>, bool)> = Vec::with_capacity(num_query_terms);
114 for (i, token) in query_tokens.iter().enumerate() {
115 let postings = self.backend.read_postings(collection, token)?;
116 if !postings.is_empty() {
117 term_postings.push((postings, false));
118 } else if fuzzy_enabled {
119 let raw = raw_tokens
121 .get(i)
122 .map(String::as_str)
123 .unwrap_or(token.as_str());
124 let (fuzzy_posts, is_fuzzy) = self.fuzzy_lookup(collection, raw)?;
125 term_postings.push((fuzzy_posts, is_fuzzy));
126 } else {
127 term_postings.push((Vec::new(), false));
128 }
129 }
130
131 let mut doc_scores: HashMap<String, (f32, bool, usize)> = HashMap::new();
134
135 for (token_idx, (postings, is_fuzzy)) in term_postings.iter().enumerate() {
136 if postings.is_empty() {
137 continue;
138 }
139 let df = postings.len() as u32;
140
141 for posting in postings {
142 let doc_len = self
143 .backend
144 .read_doc_length(collection, &posting.doc_id)?
145 .unwrap_or(1);
146
147 let mut score = bm25_score(
148 posting.term_freq,
149 df,
150 doc_len,
151 total_docs,
152 avg_doc_len,
153 &self.bm25_params,
154 );
155
156 if *is_fuzzy {
157 score *= crate::fuzzy::fuzzy_discount(1);
158 }
159
160 let entry = doc_scores
161 .entry(posting.doc_id.clone())
162 .or_insert((0.0, false, 0));
163 entry.0 += score;
164 if *is_fuzzy {
165 entry.1 = true;
166 }
167 entry.2 += 1;
168 }
169 let _ = token_idx; }
171
172 if num_query_terms >= 2 {
174 let doc_postings_map =
175 phrase::collect_doc_postings(&query_tokens, &term_postings, &self.backend);
176 for (doc_id, token_postings) in &doc_postings_map {
177 if let Some(entry) = doc_scores.get_mut(doc_id.as_str()) {
178 let boost = phrase::phrase_boost(&query_tokens, token_postings);
179 entry.0 *= boost;
180 }
181 }
182 }
183
184 if mode == QueryMode::And && num_query_terms > 1 {
186 let and_results: HashMap<String, (f32, bool, usize)> = doc_scores
187 .iter()
188 .filter(|(_, (_, _, match_count))| *match_count >= num_query_terms)
189 .map(|(k, v)| (k.clone(), *v))
190 .collect();
191
192 if !and_results.is_empty() {
193 return Ok(Self::to_sorted_results(and_results, top_k));
194 }
195
196 for (score, _, match_count) in doc_scores.values_mut() {
198 let coverage = *match_count as f32 / num_query_terms as f32;
199 *score *= coverage;
200 }
201 }
202
203 Ok(Self::to_sorted_results(doc_scores, top_k))
204 }
205
206 fn filter_and_mode(
208 &self,
209 collection: &str,
210 query_tokens: &[String],
211 candidates: &[TextSearchResult],
212 num_terms: usize,
213 ) -> Result<Vec<TextSearchResult>, B::Error> {
214 let doc_map = self.load_doc_id_map(collection)?;
215 let term_blocks = crate::lsm::query::collect_merged_term_blocks(
216 &self.backend,
217 collection,
218 self.memtable(),
219 query_tokens,
220 )?;
221
222 let mut results = Vec::new();
223 for candidate in candidates {
224 let int_id = doc_map.to_u32(&candidate.doc_id);
225 let matched = term_blocks
226 .iter()
227 .filter(|tb| {
228 int_id.is_some_and(|id| tb.blocks.iter().any(|b| b.doc_ids.contains(&id)))
229 })
230 .count();
231 if matched >= num_terms {
232 results.push(candidate.clone());
233 }
234 }
235 Ok(results)
236 }
237
238 fn count_term_matches(&self, collection: &str, query_tokens: &[String], doc_id: &str) -> usize {
240 let doc_map = match self.load_doc_id_map(collection) {
241 Ok(m) => m,
242 Err(_) => return 0,
243 };
244 let Some(int_id) = doc_map.to_u32(doc_id) else {
245 return 0;
246 };
247 let term_blocks = match crate::lsm::query::collect_merged_term_blocks(
248 &self.backend,
249 collection,
250 self.memtable(),
251 query_tokens,
252 ) {
253 Ok(tb) => tb,
254 Err(_) => return 0,
255 };
256 term_blocks
257 .iter()
258 .filter(|tb| tb.blocks.iter().any(|b| b.doc_ids.contains(&int_id)))
259 .count()
260 }
261
262 fn to_sorted_results(
264 doc_scores: HashMap<String, (f32, bool, usize)>,
265 top_k: usize,
266 ) -> Vec<TextSearchResult> {
267 let mut results: Vec<TextSearchResult> = doc_scores
268 .into_iter()
269 .map(|(doc_id, (score, fuzzy_flag, _))| TextSearchResult {
270 doc_id,
271 score,
272 fuzzy: fuzzy_flag,
273 })
274 .collect();
275 results.sort_by(|a, b| {
276 b.score
277 .partial_cmp(&a.score)
278 .unwrap_or(std::cmp::Ordering::Equal)
279 });
280 results.truncate(top_k);
281 results
282 }
283}
284
#[cfg(test)]
mod tests {
    use crate::backend::memory::MemoryBackend;
    use crate::index::FtsIndex;
    use crate::posting::QueryMode;

    /// Three-document fixture: d1 and d2 both contain "brown" and "dog"; d3 shares
    /// no terms with them.
    fn make_index() -> FtsIndex<MemoryBackend> {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document("docs", "d1", "The quick brown fox jumps over the lazy dog")
            .unwrap();
        idx.index_document("docs", "d2", "A fast brown dog runs across the field")
            .unwrap();
        idx.index_document("docs", "d3", "Rust programming language for systems")
            .unwrap();
        idx
    }

    #[test]
    fn basic_search() {
        let idx = make_index();
        let results = idx.search("docs", "brown fox", 10, false).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].doc_id, "d1");
    }

    #[test]
    fn search_with_stemming() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document("docs", "d1", "running distributed databases")
            .unwrap();
        idx.index_document("docs", "d2", "the cat sat on a mat")
            .unwrap();

        // Query and document terms only meet after stemming.
        let results = idx
            .search("docs", "database distribution", 10, false)
            .unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].doc_id, "d1");
    }

    #[test]
    fn or_mode() {
        let idx = make_index();
        let results = idx
            .search_with_mode("docs", "brown fox", 10, false, QueryMode::Or)
            .unwrap();
        assert!(results.len() >= 2);
    }

    #[test]
    fn and_mode_filters() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document("docs", "d1", "Rust programming language")
            .unwrap();
        idx.index_document("docs", "d2", "Python programming language")
            .unwrap();

        let results = idx
            .search_with_mode("docs", "rust programming", 10, false, QueryMode::And)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, "d1");
    }

    #[test]
    fn and_fallback_to_or() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document("docs", "d1", "rust programming language")
            .unwrap();
        idx.index_document("docs", "d2", "python programming language")
            .unwrap();

        // No document contains both terms, so AND degrades to coverage-penalized OR.
        let results = idx.search("docs", "rust python", 10, false).unwrap();
        assert_eq!(results.len(), 2);
        for r in &results {
            // The penalty scales scores down but never zeroes a partial match.
            assert!(r.score > 0.0);
        }
    }

    #[test]
    fn and_no_fallback_when_results_exist() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document("docs", "d1", "rust programming language")
            .unwrap();
        idx.index_document("docs", "d2", "python programming language")
            .unwrap();

        // d1 matches both terms, so the partial match d2 must not leak in.
        let results = idx.search("docs", "rust programming", 10, false).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, "d1");
    }

    #[test]
    fn empty_query() {
        let idx = make_index();
        let results = idx.search("docs", "the a is", 10, false).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn collections_isolated() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document("col_a", "d1", "alpha bravo charlie")
            .unwrap();
        idx.index_document("col_b", "d1", "delta echo foxtrot")
            .unwrap();

        assert_eq!(idx.search("col_a", "alpha", 10, false).unwrap().len(), 1);
        assert!(idx.search("col_b", "alpha", 10, false).unwrap().is_empty());
    }

    #[test]
    fn fuzzy_search() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document("docs", "d1", "distributed database systems")
            .unwrap();

        // "databse" is a one-character typo of "database".
        let results = idx.search("docs", "databse", 10, true).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].fuzzy);
    }

    #[test]
    fn phrase_boost_consecutive() {
        let idx = FtsIndex::new(MemoryBackend::new());
        // d1 has "brown fox" adjacent; d2 has both terms but far apart.
        idx.index_document("docs", "d1", "the quick brown fox jumped")
            .unwrap();
        idx.index_document("docs", "d2", "a brown dog chased a fox")
            .unwrap();

        let results = idx
            .search_with_mode("docs", "brown fox", 10, false, QueryMode::Or)
            .unwrap();
        assert!(results.len() >= 2);
        assert_eq!(results[0].doc_id, "d1");
    }

    #[test]
    fn phrase_boost_no_effect_single_term() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document("docs", "d1", "hello world").unwrap();

        let results = idx.search("docs", "hello", 10, false).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn top_k_limits_results() {
        // Both d1 and d2 contain "brown"; top_k = 1 must cap the result list.
        let idx = make_index();
        let results = idx
            .search_with_mode("docs", "brown", 1, false, QueryMode::Or)
            .unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn and_mode_single_term_is_not_filtered() {
        // A one-term AND query matches every document containing that term.
        let idx = make_index();
        let results = idx.search("docs", "brown", 10, false).unwrap();
        assert_eq!(results.len(), 2);
    }
}