//! memory_indexer/search/query.rs — query execution over the in-memory index.

1use std::{
2    cmp::Ordering,
3    collections::{HashMap, HashSet},
4};
5
6use super::{
7    super::{
8        tokenizer::Token,
9        types::{DocData, InMemoryIndex, SearchMode, domain_config},
10    },
11    MatchedTerm, SearchHit, TermDomain,
12    scoring::{
13        MIN_SHOULD_MATCH_RATIO, bm25_component, compute_min_should_match, has_minimum_should_match,
14        score_fuzzy_terms,
15    },
16};
17
/// A query term resolved against one index domain, bundling everything the
/// BM25 loop needs: the term text, its posting list, and the domain's weight.
struct TermView<'a> {
    // Normalized term text as produced by the tokenizer.
    term: String,
    // Posting list for this term: doc id -> term frequency in that document.
    postings: &'a HashMap<String, i64>,
    // Domain-specific score multiplier (from `domain_config`).
    weight: f64,
    // The domain these postings came from.
    domain: TermDomain,
}
24
25impl InMemoryIndex {
26    /// Execute an auto-mode search and return doc ids with scores.
27    pub fn search(&self, index_name: &str, query: &str) -> Vec<(String, f64)> {
28        self.search_with_mode_hits(index_name, query, SearchMode::Auto)
29            .into_iter()
30            .map(|hit| (hit.doc_id, hit.score))
31            .collect()
32    }
33
34    /// Execute an auto-mode search and return full hits including matched terms.
35    pub fn search_hits(&self, index_name: &str, query: &str) -> Vec<SearchHit> {
36        self.search_with_mode_hits(index_name, query, SearchMode::Auto)
37    }
38
39    /// Execute a search in the specified mode and return doc ids with scores.
40    pub fn search_with_mode(
41        &self,
42        index_name: &str,
43        query: &str,
44        mode: SearchMode,
45    ) -> Vec<(String, f64)> {
46        self.search_with_mode_hits(index_name, query, mode)
47            .into_iter()
48            .map(|hit| (hit.doc_id, hit.score))
49            .collect()
50    }
51
52    /// Execute a search in the specified mode and return full hits including matched terms.
53    pub fn search_with_mode_hits(
54        &self,
55        index_name: &str,
56        query: &str,
57        mode: SearchMode,
58    ) -> Vec<SearchHit> {
59        if query == "*" || query.is_empty() {
60            if let Some(docs) = self.docs.get(index_name) {
61                return docs
62                    .keys()
63                    .map(|k| SearchHit {
64                        doc_id: k.clone(),
65                        score: 1.0,
66                        matched_terms: Vec::new(),
67                    })
68                    .collect();
69            }
70            return vec![];
71        }
72
73        let query_terms = self.tokenize_query(query);
74        if query_terms.is_empty() {
75            return vec![];
76        }
77
78        match mode {
79            SearchMode::Exact => self.bm25_search(index_name, &query_terms, TermDomain::Original),
80            SearchMode::Pinyin => self.pinyin_search(index_name, &query_terms),
81            SearchMode::Fuzzy => self.fuzzy_search(index_name, &query_terms),
82            SearchMode::Auto => {
83                let exact = self.bm25_search(index_name, &query_terms, TermDomain::Original);
84                if has_minimum_should_match(&exact, query_terms.len()) {
85                    // Stop at exact-domain hits when they already satisfy recall, so we don't
86                    // dilute precision by falling through to fuzzier heuristics.
87                    return exact;
88                }
89
90                if !is_ascii_alphanumeric_query(&query_terms) {
91                    return self.fuzzy_search_internal(index_name, &query_terms, true);
92                }
93
94                let pinyin_prefix = self.pinyin_prefix_search(index_name, &query_terms);
95                if has_minimum_should_match(&pinyin_prefix, query_terms.len()) {
96                    return pinyin_prefix;
97                }
98
99                let pinyin_exact = self.pinyin_exact_search(index_name, &query_terms);
100                if has_minimum_should_match(&pinyin_exact, query_terms.len()) {
101                    return pinyin_exact;
102                }
103
104                if is_ascii_alphanumeric_query(&query_terms) {
105                    let fuzzy_original = self.fuzzy_search(index_name, &query_terms);
106                    if !fuzzy_original.is_empty() {
107                        return fuzzy_original;
108                    }
109                } else {
110                    let cjk_fuzzy = self.fuzzy_search_internal(index_name, &query_terms, true);
111                    if !cjk_fuzzy.is_empty() {
112                        return cjk_fuzzy;
113                    }
114                }
115
116                self.fuzzy_pinyin_search(index_name, &query_terms)
117            }
118        }
119    }
120
121    fn bm25_search(
122        &self,
123        index_name: &str,
124        query_terms: &[Token],
125        domain: TermDomain,
126    ) -> Vec<SearchHit> {
127        if query_terms.is_empty() {
128            return vec![];
129        }
130
131        let domains = match self.domains.get(index_name) {
132            Some(d) => d,
133            None => return vec![],
134        };
135
136        let domain_index = match domains.get(&domain) {
137            Some(idx) => idx,
138            None => return vec![],
139        };
140
141        let docs = match self.docs.get(index_name) {
142            Some(d) => d,
143            None => return vec![],
144        };
145
146        let mut term_views: Vec<TermView<'_>> = Vec::new();
147        let weight = domain_config(domain).weight;
148
149        for token in query_terms {
150            let Some(doc_map) = domain_index.postings.get(&token.term) else {
151                continue;
152            };
153
154            if doc_map.is_empty() {
155                continue;
156            }
157
158            term_views.push(TermView {
159                term: token.term.clone(),
160                postings: doc_map,
161                weight,
162                domain,
163            });
164        }
165
166        if term_views.is_empty() {
167            return vec![];
168        }
169
170        let min_should_match =
171            compute_min_should_match(query_terms.len(), term_views.len(), MIN_SHOULD_MATCH_RATIO);
172
173        let n = docs.len() as f64;
174        if n <= 0.0 {
175            return vec![];
176        }
177        let avgdl = average_doc_len(self, index_name, domain, docs.len());
178
179        let mut idfs = HashMap::new();
180        for view in &term_views {
181            let n_q = view.postings.len() as f64;
182            let idf = ((n - n_q + 0.5) / (n_q + 0.5) + 1.0).ln();
183            idfs.insert(view.term.clone(), idf);
184        }
185
186        let mut matches: HashMap<String, HashSet<MatchedTerm>> = HashMap::new();
187        let mut doc_scores: HashMap<String, f64> = HashMap::new();
188        for view in &term_views {
189            for (doc_id, freq) in view.postings {
190                let Some(doc_data) = docs.get(doc_id) else {
191                    continue;
192                };
193                let idf = *idfs.get(&view.term).unwrap_or(&0.0);
194                let component = bm25_component(
195                    *freq as f64,
196                    doc_len_for_domain(doc_data, view.domain),
197                    avgdl,
198                    idf,
199                ) * view.weight;
200                if component > 0.0 {
201                    *doc_scores.entry(doc_id.clone()).or_default() += component;
202                    matches
203                        .entry(doc_id.clone())
204                        .or_default()
205                        .insert(MatchedTerm::new(view.term.clone(), view.domain));
206                }
207            }
208        }
209
210        let mut scores: Vec<(String, f64)> = doc_scores
211            .into_iter()
212            .filter(|(doc_id, _)| {
213                matches
214                    .get(doc_id)
215                    .map(|set| set.len() >= min_should_match)
216                    .unwrap_or(false)
217            })
218            .collect();
219        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
220        scores
221            .into_iter()
222            .map(|(doc_id, score)| SearchHit {
223                doc_id: doc_id.clone(),
224                score,
225                matched_terms: matches
226                    .remove(&doc_id)
227                    .map(|s| s.into_iter().collect())
228                    .unwrap_or_default(),
229            })
230            .collect()
231    }
232
233    fn pinyin_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
234        if !is_ascii_alphanumeric_query(query_terms) {
235            return vec![];
236        }
237
238        let exact = self.pinyin_exact_search(index_name, query_terms);
239        if !exact.is_empty() {
240            return exact;
241        }
242
243        self.pinyin_prefix_search(index_name, query_terms)
244    }
245
246    fn pinyin_prefix_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
247        let full_prefix = self.bm25_search(index_name, query_terms, TermDomain::PinyinFullPrefix);
248        if !full_prefix.is_empty() {
249            return full_prefix;
250        }
251
252        self.bm25_search(index_name, query_terms, TermDomain::PinyinInitialsPrefix)
253    }
254
255    fn pinyin_exact_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
256        let full = self.bm25_search(index_name, query_terms, TermDomain::PinyinFull);
257        if !full.is_empty() {
258            return full;
259        }
260
261        self.bm25_search(index_name, query_terms, TermDomain::PinyinInitials)
262    }
263
264    fn fuzzy_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
265        self.fuzzy_search_internal(index_name, query_terms, false)
266    }
267
268    fn fuzzy_search_internal(
269        &self,
270        index_name: &str,
271        query_terms: &[Token],
272        allow_non_ascii: bool,
273    ) -> Vec<SearchHit> {
274        self.fuzzy_search_in_domain(
275            index_name,
276            query_terms,
277            TermDomain::Original,
278            allow_non_ascii,
279        )
280    }
281
282    fn fuzzy_pinyin_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
283        if query_terms.is_empty() || !is_ascii_alphanumeric_query(query_terms) {
284            return vec![];
285        }
286
287        let full =
288            self.fuzzy_search_in_domain(index_name, query_terms, TermDomain::PinyinFull, false);
289        if !full.is_empty() {
290            return full;
291        }
292
293        self.fuzzy_search_in_domain(index_name, query_terms, TermDomain::PinyinInitials, false)
294    }
295
296    fn fuzzy_search_in_domain(
297        &self,
298        index_name: &str,
299        query_terms: &[Token],
300        domain: TermDomain,
301        allow_non_ascii: bool,
302    ) -> Vec<SearchHit> {
303        if query_terms.is_empty() || (!allow_non_ascii && !is_ascii_alphanumeric_query(query_terms))
304        {
305            return vec![];
306        }
307
308        if !domain_config(domain).allow_fuzzy {
309            return vec![];
310        }
311
312        let docs = match self.docs.get(index_name) {
313            Some(d) => d,
314            None => return vec![],
315        };
316
317        let domains = match self.domains.get(index_name) {
318            Some(d) => d,
319            None => return vec![],
320        };
321        let domain_index = match domains.get(&domain) {
322            Some(idx) => idx,
323            None => return vec![],
324        };
325
326        let n = docs.len() as f64;
327        if n <= 0.0 {
328            return vec![];
329        }
330        let avgdl = average_doc_len(self, index_name, domain, docs.len());
331
332        let mut doc_scores: HashMap<String, f64> = HashMap::new();
333        let mut matched_terms: HashMap<String, HashSet<MatchedTerm>> = HashMap::new();
334        let weight = domain_config(domain).weight;
335        let mut matched_query_tokens: HashMap<String, HashSet<usize>> = HashMap::new();
336        let mut tokens_with_candidates: HashSet<usize> = HashSet::new();
337
338        for (idx, token) in query_terms.iter().enumerate() {
339            score_fuzzy_terms(
340                docs,
341                domain_index,
342                n,
343                avgdl,
344                &mut doc_scores,
345                &mut matched_terms,
346                &mut matched_query_tokens,
347                &mut tokens_with_candidates,
348                domain,
349                weight,
350                &token.term,
351                &|doc_data| doc_len_for_domain(doc_data, domain),
352                idx,
353            );
354        }
355
356        let available_terms = tokens_with_candidates.len();
357        let min_should_match =
358            // Only count query terms that actually produced fuzzy candidates; otherwise we
359            // would unfairly drop hits because of tokens with zero recall paths.
360            compute_min_should_match(query_terms.len(), available_terms, MIN_SHOULD_MATCH_RATIO);
361
362        let mut scores: Vec<(String, f64)> = doc_scores
363            .into_iter()
364            .filter(|(doc_id, _)| {
365                matched_query_tokens
366                    .get(doc_id)
367                    .map(|set| set.len() >= min_should_match)
368                    .unwrap_or(false)
369            })
370            .collect();
371        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
372        scores
373            .into_iter()
374            .map(|(doc_id, score)| SearchHit {
375                matched_terms: matched_terms
376                    .remove(&doc_id)
377                    .map(|s| s.into_iter().collect())
378                    .unwrap_or_default(),
379                doc_id,
380                score,
381            })
382            .collect()
383    }
384}
385
386pub(super) fn is_ascii_alphanumeric_query(tokens: &[Token]) -> bool {
387    tokens
388        .iter()
389        .all(|token| token.term.chars().all(|c| c.is_ascii_alphanumeric()))
390}
391
392fn doc_len_for_domain(doc_data: &DocData, domain: TermDomain) -> f64 {
393    if domain.is_prefix() {
394        // Prefix domains reuse positions but skip length normalization so short prefixes
395        // are not penalized compared to full tokens.
396        return 0.0;
397    }
398
399    let len = doc_data.domain_doc_len.get(domain);
400    if len > 0 {
401        len as f64
402    } else {
403        doc_data.doc_len as f64
404    }
405}
406
407fn average_doc_len(
408    index: &InMemoryIndex,
409    index_name: &str,
410    domain: TermDomain,
411    doc_count: usize,
412) -> f64 {
413    if domain.is_prefix() || doc_count == 0 {
414        return 0.0;
415    }
416
417    let total = index
418        .domain_total_lens
419        .get(index_name)
420        .map(|m| m.get(domain))
421        .unwrap_or(0);
422    if total <= 0 {
423        0.0
424    } else {
425        total as f64 / doc_count as f64
426    }
427}