// memory_indexer/search/query.rs
1use std::{
2    cmp::Ordering,
3    collections::{HashMap, HashSet},
4};
5
6use smol_str::SmolStr;
7
8use super::{
9    super::{
10        ngram::{build_ngram_index, should_index_in_original_aux},
11        tokenizer::Token,
12        types::{
13            DocData, DocId, InMemoryIndex, IndexState, Posting, SearchMode, TermDomain, TermId,
14            domain_config,
15        },
16    },
17    MatchedTerm, SearchHit,
18    scoring::{
19        MIN_SHOULD_MATCH_RATIO, bm25_component, compute_min_should_match, has_minimum_should_match,
20        score_fuzzy_terms,
21    },
22};
23
/// Minimum query length (bytes; terms are ASCII-checked before slicing) for
/// prefix matching against full-pinyin terms.
const PINYIN_FULL_PREFIX_MIN: usize = 2;
/// Minimum query length for prefix matching against pinyin-initials terms.
const PINYIN_INITIALS_PREFIX_MIN: usize = 1;
/// Longest prefix materialized in the cached prefix index; query tokens longer
/// than this are skipped by prefix search.
const PINYIN_PREFIX_MAX: usize = 16;
27
/// A query term resolved against a single term domain, bundling everything
/// needed to score it with BM25.
struct TermView<'a> {
    term_id: TermId,
    // Human-readable term text, reported back in `MatchedTerm`s on hits.
    term_text: String,
    // Borrowed posting list for this term within the domain.
    postings: &'a [Posting],
    // Domain scoring weight (from `domain_config`), multiplied into each BM25 component.
    weight: f64,
    domain: TermDomain,
}
35
36impl InMemoryIndex {
37    /// Execute an auto-mode search and return doc ids with scores.
38    pub fn search(&self, index_name: &str, query: &str) -> Vec<(String, f64)> {
39        self.search_with_mode_hits(index_name, query, SearchMode::Auto)
40            .into_iter()
41            .map(|hit| (hit.doc_id, hit.score))
42            .collect()
43    }
44
45    /// Execute an auto-mode search and return full hits including matched terms.
46    pub fn search_hits(&self, index_name: &str, query: &str) -> Vec<SearchHit> {
47        self.search_with_mode_hits(index_name, query, SearchMode::Auto)
48    }
49
50    /// Execute a search in the specified mode and return doc ids with scores.
51    pub fn search_with_mode(
52        &self,
53        index_name: &str,
54        query: &str,
55        mode: SearchMode,
56    ) -> Vec<(String, f64)> {
57        self.search_with_mode_hits(index_name, query, mode)
58            .into_iter()
59            .map(|hit| (hit.doc_id, hit.score))
60            .collect()
61    }
62
63    /// Execute a search in the specified mode and return full hits including matched terms.
64    pub fn search_with_mode_hits(
65        &self,
66        index_name: &str,
67        query: &str,
68        mode: SearchMode,
69    ) -> Vec<SearchHit> {
70        if query == "*" || query.is_empty() {
71            if let Some(state) = self.indexes.get(index_name) {
72                return state
73                    .doc_index
74                    .keys()
75                    .map(|doc_id| SearchHit {
76                        doc_id: doc_id.to_string(),
77                        score: 1.0,
78                        matched_terms: Vec::new(),
79                    })
80                    .collect();
81            }
82            return vec![];
83        }
84
85        let query_terms = self.tokenize_query(query);
86        if query_terms.is_empty() {
87            return vec![];
88        }
89
90        match mode {
91            SearchMode::Exact => self.bm25_search(index_name, &query_terms, TermDomain::Original),
92            SearchMode::Pinyin => self.pinyin_search(index_name, &query_terms),
93            SearchMode::Fuzzy => self.fuzzy_search(index_name, &query_terms),
94            SearchMode::Auto => {
95                let exact = self.bm25_search(index_name, &query_terms, TermDomain::Original);
96                if has_minimum_should_match(&exact, query_terms.len()) {
97                    // Stop at exact-domain hits when they already satisfy recall, so we don't
98                    // dilute precision by falling through to fuzzier heuristics.
99                    return exact;
100                }
101
102                if !is_ascii_alphanumeric_query(&query_terms) {
103                    return self.fuzzy_search_internal(index_name, &query_terms, true);
104                }
105
106                let pinyin_prefix = self.pinyin_prefix_search(index_name, &query_terms);
107                if has_minimum_should_match(&pinyin_prefix, query_terms.len()) {
108                    return pinyin_prefix;
109                }
110
111                let pinyin_exact = self.pinyin_exact_search(index_name, &query_terms);
112                if has_minimum_should_match(&pinyin_exact, query_terms.len()) {
113                    return pinyin_exact;
114                }
115
116                if is_ascii_alphanumeric_query(&query_terms) {
117                    let fuzzy_original = self.fuzzy_search(index_name, &query_terms);
118                    if !fuzzy_original.is_empty() {
119                        return fuzzy_original;
120                    }
121                } else {
122                    let cjk_fuzzy = self.fuzzy_search_internal(index_name, &query_terms, true);
123                    if !cjk_fuzzy.is_empty() {
124                        return cjk_fuzzy;
125                    }
126                }
127
128                self.fuzzy_pinyin_search(index_name, &query_terms)
129            }
130        }
131    }
132
133    fn bm25_search(
134        &self,
135        index_name: &str,
136        query_terms: &[Token],
137        domain: TermDomain,
138    ) -> Vec<SearchHit> {
139        if query_terms.is_empty() {
140            return vec![];
141        }
142
143        let state = match self.indexes.get(index_name) {
144            Some(state) => state,
145            None => return vec![],
146        };
147
148        let domain_index = match state.domains.get(&domain) {
149            Some(idx) => idx,
150            None => return vec![],
151        };
152
153        let doc_count = state.doc_index.len();
154        if doc_count == 0 {
155            return vec![];
156        }
157
158        let mut term_views: Vec<TermView<'_>> = Vec::new();
159        let weight = domain_config(domain).weight;
160
161        for token in query_terms {
162            let Some(&term_id) = state.term_index.get(token.term.as_str()) else {
163                continue;
164            };
165            let Some(postings) = domain_index.postings.get(&term_id) else {
166                continue;
167            };
168            if postings.is_empty() {
169                continue;
170            }
171            let term_text = state
172                .terms
173                .get(term_id as usize)
174                .map(|term| term.as_str().to_string())
175                .unwrap_or_else(|| token.term.clone());
176            term_views.push(TermView {
177                term_id,
178                term_text,
179                postings,
180                weight,
181                domain,
182            });
183        }
184
185        if term_views.is_empty() {
186            return vec![];
187        }
188
189        let min_should_match =
190            compute_min_should_match(query_terms.len(), term_views.len(), MIN_SHOULD_MATCH_RATIO);
191
192        let n = doc_count as f64;
193        let avgdl = average_doc_len(state, domain, doc_count);
194
195        let mut idfs = HashMap::new();
196        for view in &term_views {
197            let n_q = view.postings.len() as f64;
198            let idf = ((n - n_q + 0.5) / (n_q + 0.5) + 1.0).ln();
199            idfs.insert(view.term_id, idf);
200        }
201
202        let mut matches: HashMap<DocId, HashSet<MatchedTerm>> = HashMap::new();
203        let mut doc_scores: HashMap<DocId, f64> = HashMap::new();
204        for view in &term_views {
205            let idf = *idfs.get(&view.term_id).unwrap_or(&0.0);
206            for posting in view.postings {
207                let Some(doc_data) = state
208                    .docs
209                    .get(posting.doc as usize)
210                    .and_then(|doc| doc.as_ref())
211                else {
212                    continue;
213                };
214                let component = bm25_component(
215                    posting.freq as f64,
216                    doc_len_for_domain(doc_data, view.domain),
217                    avgdl,
218                    idf,
219                ) * view.weight;
220                if component > 0.0 {
221                    *doc_scores.entry(posting.doc).or_default() += component;
222                    matches
223                        .entry(posting.doc)
224                        .or_default()
225                        .insert(MatchedTerm::new(view.term_text.clone(), view.domain));
226                }
227            }
228        }
229
230        let mut scores: Vec<(DocId, f64)> = doc_scores
231            .into_iter()
232            .filter(|(doc_id, _)| {
233                matches
234                    .get(doc_id)
235                    .map(|set| set.len() >= min_should_match)
236                    .unwrap_or(false)
237            })
238            .collect();
239        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
240        scores
241            .into_iter()
242            .filter_map(|(doc_id, score)| {
243                let doc_name = state.doc_ids.get(doc_id as usize)?.to_string();
244                Some(SearchHit {
245                    doc_id: doc_name,
246                    score,
247                    matched_terms: matches
248                        .remove(&doc_id)
249                        .map(|s| s.into_iter().collect())
250                        .unwrap_or_default(),
251                })
252            })
253            .collect()
254    }
255
256    fn pinyin_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
257        if !is_ascii_alphanumeric_query(query_terms) {
258            return vec![];
259        }
260
261        let exact = self.pinyin_exact_search(index_name, query_terms);
262        if !exact.is_empty() {
263            return exact;
264        }
265
266        self.pinyin_prefix_search(index_name, query_terms)
267    }
268
269    fn pinyin_prefix_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
270        let full_prefix = self.prefix_search_in_domain(
271            index_name,
272            query_terms,
273            TermDomain::PinyinFull,
274            PINYIN_FULL_PREFIX_MIN,
275        );
276        if !full_prefix.is_empty() {
277            return full_prefix;
278        }
279
280        self.prefix_search_in_domain(
281            index_name,
282            query_terms,
283            TermDomain::PinyinInitials,
284            PINYIN_INITIALS_PREFIX_MIN,
285        )
286    }
287
288    fn pinyin_exact_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
289        let full = self.bm25_search(index_name, query_terms, TermDomain::PinyinFull);
290        if !full.is_empty() {
291            return full;
292        }
293
294        self.bm25_search(index_name, query_terms, TermDomain::PinyinInitials)
295    }
296
297    fn fuzzy_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
298        self.fuzzy_search_internal(index_name, query_terms, false)
299    }
300
301    fn fuzzy_search_internal(
302        &self,
303        index_name: &str,
304        query_terms: &[Token],
305        allow_non_ascii: bool,
306    ) -> Vec<SearchHit> {
307        self.fuzzy_search_in_domain(
308            index_name,
309            query_terms,
310            TermDomain::Original,
311            allow_non_ascii,
312        )
313    }
314
315    fn fuzzy_pinyin_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
316        if query_terms.is_empty() || !is_ascii_alphanumeric_query(query_terms) {
317            return vec![];
318        }
319
320        let full =
321            self.fuzzy_search_in_domain(index_name, query_terms, TermDomain::PinyinFull, false);
322        if !full.is_empty() {
323            return full;
324        }
325
326        self.fuzzy_search_in_domain(index_name, query_terms, TermDomain::PinyinInitials, false)
327    }
328
329    fn fuzzy_search_in_domain(
330        &self,
331        index_name: &str,
332        query_terms: &[Token],
333        domain: TermDomain,
334        allow_non_ascii: bool,
335    ) -> Vec<SearchHit> {
336        if query_terms.is_empty() || (!allow_non_ascii && !is_ascii_alphanumeric_query(query_terms))
337        {
338            return vec![];
339        }
340
341        if !domain_config(domain).allow_fuzzy {
342            return vec![];
343        }
344
345        let state = match self.indexes.get(index_name) {
346            Some(state) => state,
347            None => return vec![],
348        };
349
350        let domain_index = match state.domains.get(&domain) {
351            Some(idx) => idx,
352            None => return vec![],
353        };
354
355        let doc_count = state.doc_index.len();
356        if doc_count == 0 {
357            return vec![];
358        }
359
360        {
361            let mut aux = domain_index.aux.write().unwrap();
362            if aux.term_ids.is_none() {
363                let mut ids: Vec<TermId> = domain_index
364                    .postings
365                    .keys()
366                    .copied()
367                    .filter(|term_id| {
368                        if domain == TermDomain::Original {
369                            state
370                                .terms
371                                .get(*term_id as usize)
372                                .map(|term| should_index_in_original_aux(term.as_str()))
373                                .unwrap_or(false)
374                        } else {
375                            true
376                        }
377                    })
378                    .collect();
379                ids.sort_unstable();
380                aux.term_ids = Some(ids);
381            }
382            if aux.ngram_index.is_none() {
383                let ids = aux.term_ids.as_ref().unwrap();
384                aux.ngram_index = Some(build_ngram_index(ids, &state.terms));
385            }
386        }
387        let aux = domain_index.aux.read().unwrap();
388        let term_ids = aux.term_ids.as_ref().unwrap();
389        let ngram_index = aux.ngram_index.as_ref().unwrap();
390
391        let n = doc_count as f64;
392        let avgdl = average_doc_len(state, domain, doc_count);
393
394        let mut doc_scores: HashMap<DocId, f64> = HashMap::new();
395        let mut matched_terms: HashMap<DocId, HashSet<MatchedTerm>> = HashMap::new();
396        let weight = domain_config(domain).weight;
397        let mut matched_query_tokens: HashMap<DocId, HashSet<usize>> = HashMap::new();
398        let mut tokens_with_candidates: HashSet<usize> = HashSet::new();
399
400        for (idx, token) in query_terms.iter().enumerate() {
401            let exact_term = state.term_index.get(token.term.as_str()).copied();
402            score_fuzzy_terms(
403                &state.docs,
404                domain_index,
405                term_ids,
406                &state.terms,
407                ngram_index,
408                n,
409                avgdl,
410                &mut doc_scores,
411                &mut matched_terms,
412                &mut matched_query_tokens,
413                &mut tokens_with_candidates,
414                domain,
415                weight,
416                &token.term,
417                idx,
418                exact_term,
419            );
420        }
421
422        let available_terms = tokens_with_candidates.len();
423        let min_should_match =
424            // Only count query terms that actually produced fuzzy candidates; otherwise we
425            // would unfairly drop hits because of tokens with zero recall paths.
426            compute_min_should_match(query_terms.len(), available_terms, MIN_SHOULD_MATCH_RATIO);
427
428        let mut scores: Vec<(DocId, f64)> = doc_scores
429            .into_iter()
430            .filter(|(doc_id, _)| {
431                matched_query_tokens
432                    .get(doc_id)
433                    .map(|set| set.len() >= min_should_match)
434                    .unwrap_or(false)
435            })
436            .collect();
437        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
438        scores
439            .into_iter()
440            .filter_map(|(doc_id, score)| {
441                let doc_name = state.doc_ids.get(doc_id as usize)?.to_string();
442                Some(SearchHit {
443                    matched_terms: matched_terms
444                        .remove(&doc_id)
445                        .map(|s| s.into_iter().collect())
446                        .unwrap_or_default(),
447                    doc_id: doc_name,
448                    score,
449                })
450            })
451            .collect()
452    }
453
454    fn prefix_search_in_domain(
455        &self,
456        index_name: &str,
457        query_terms: &[Token],
458        domain: TermDomain,
459        min_prefix_len: usize,
460    ) -> Vec<SearchHit> {
461        if query_terms.is_empty() || !is_ascii_alphanumeric_query(query_terms) {
462            return vec![];
463        }
464
465        let state = match self.indexes.get(index_name) {
466            Some(state) => state,
467            None => return vec![],
468        };
469
470        let domain_index = match state.domains.get(&domain) {
471            Some(idx) => idx,
472            None => return vec![],
473        };
474
475        let doc_count = state.doc_index.len();
476        if doc_count == 0 {
477            return vec![];
478        }
479
480        {
481            let mut aux = domain_index.aux.write().unwrap();
482            if aux.term_ids.is_none() {
483                let mut ids: Vec<TermId> = domain_index.postings.keys().copied().collect();
484                ids.sort_unstable();
485                aux.term_ids = Some(ids);
486            }
487            if aux.prefix_index.is_none() {
488                let mut prefix_index: HashMap<SmolStr, Vec<TermId>> = HashMap::new();
489                let ids = aux.term_ids.as_ref().unwrap();
490                for &term_id in ids {
491                    let Some(term) = state.terms.get(term_id as usize) else {
492                        continue;
493                    };
494                    if !term.as_str().is_ascii() {
495                        continue;
496                    }
497                    let term_len = term.len();
498                    if term_len < min_prefix_len {
499                        continue;
500                    }
501                    let max = PINYIN_PREFIX_MAX.min(term_len);
502                    for len in min_prefix_len..=max {
503                        let prefix = SmolStr::new(&term.as_str()[..len]);
504                        prefix_index.entry(prefix).or_default().push(term_id);
505                    }
506                }
507                aux.prefix_index = Some(prefix_index);
508            }
509        }
510        let aux = domain_index.aux.read().unwrap();
511        let prefix_index = aux.prefix_index.as_ref().unwrap();
512
513        let n = doc_count as f64;
514        let avgdl = average_doc_len(state, domain, doc_count);
515
516        let mut doc_scores: HashMap<DocId, f64> = HashMap::new();
517        let mut matched_terms: HashMap<DocId, HashSet<MatchedTerm>> = HashMap::new();
518        let weight = domain_config(domain).weight;
519        let mut matched_query_tokens: HashMap<DocId, HashSet<usize>> = HashMap::new();
520        let mut tokens_with_candidates: HashSet<usize> = HashSet::new();
521
522        for (idx, token) in query_terms.iter().enumerate() {
523            if token.term.len() < min_prefix_len || token.term.len() > PINYIN_PREFIX_MAX {
524                continue;
525            }
526
527            let Some(candidates) = prefix_index.get(token.term.as_str()) else {
528                continue;
529            };
530            if candidates.is_empty() {
531                continue;
532            }
533
534            tokens_with_candidates.insert(idx);
535
536            for &candidate in candidates {
537                let Some(postings) = domain_index.postings.get(&candidate) else {
538                    continue;
539                };
540                if postings.is_empty() {
541                    continue;
542                }
543
544                let n_q = postings.len() as f64;
545                let idf = ((n - n_q + 0.5) / (n_q + 0.5) + 1.0).ln();
546                let candidate_text = state
547                    .terms
548                    .get(candidate as usize)
549                    .map(|term| term.as_str().to_string())
550                    .unwrap_or_else(|| token.term.clone());
551
552                for posting in postings {
553                    let Some(doc_data) = state
554                        .docs
555                        .get(posting.doc as usize)
556                        .and_then(|doc| doc.as_ref())
557                    else {
558                        continue;
559                    };
560                    let term_score = bm25_component(
561                        posting.freq as f64,
562                        doc_len_for_domain(doc_data, domain),
563                        avgdl,
564                        idf,
565                    ) * weight;
566                    if term_score > 0.0 {
567                        *doc_scores.entry(posting.doc).or_default() += term_score;
568                        matched_terms
569                            .entry(posting.doc)
570                            .or_default()
571                            .insert(MatchedTerm::new(candidate_text.clone(), domain));
572                        matched_query_tokens
573                            .entry(posting.doc)
574                            .or_default()
575                            .insert(idx);
576                    }
577                }
578            }
579        }
580
581        let available_terms = tokens_with_candidates.len();
582        let min_should_match =
583            compute_min_should_match(query_terms.len(), available_terms, MIN_SHOULD_MATCH_RATIO);
584
585        let mut scores: Vec<(DocId, f64)> = doc_scores
586            .into_iter()
587            .filter(|(doc_id, _)| {
588                matched_query_tokens
589                    .get(doc_id)
590                    .map(|set| set.len() >= min_should_match)
591                    .unwrap_or(false)
592            })
593            .collect();
594        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
595        scores
596            .into_iter()
597            .filter_map(|(doc_id, score)| {
598                let doc_name = state.doc_ids.get(doc_id as usize)?.to_string();
599                Some(SearchHit {
600                    matched_terms: matched_terms
601                        .remove(&doc_id)
602                        .map(|s| s.into_iter().collect())
603                        .unwrap_or_default(),
604                    doc_id: doc_name,
605                    score,
606                })
607            })
608            .collect()
609    }
610}
611
612pub(super) fn is_ascii_alphanumeric_query(tokens: &[Token]) -> bool {
613    tokens
614        .iter()
615        .all(|token| token.term.chars().all(|c| c.is_ascii_alphanumeric()))
616}
617
618fn doc_len_for_domain(doc_data: &DocData, domain: TermDomain) -> f64 {
619    let len = doc_data.domain_doc_len.get(domain);
620    if len > 0 {
621        len as f64
622    } else {
623        doc_data.doc_len as f64
624    }
625}
626
627fn average_doc_len(state: &IndexState, domain: TermDomain, doc_count: usize) -> f64 {
628    if doc_count == 0 {
629        return 0.0;
630    }
631
632    let total = state.domain_total_len.get(domain);
633    if total <= 0 {
634        0.0
635    } else {
636        total as f64 / doc_count as f64
637    }
638}