1use std::{
2 cmp::Ordering,
3 collections::{HashMap, HashSet},
4};
5
6use super::{
7 super::{
8 tokenizer::Token,
9 types::{DocData, InMemoryIndex, SearchMode, domain_config},
10 },
11 MatchedTerm, SearchHit, TermDomain,
12 scoring::{
13 MIN_SHOULD_MATCH_RATIO, bm25_component, compute_min_should_match, has_minimum_should_match,
14 score_fuzzy_terms,
15 },
16};
17
/// A query term resolved against one domain index: the term itself, a
/// borrowed view of its postings, and the scoring weight of the domain
/// it was found in. Built by `bm25_search` for each matching term.
struct TermView<'a> {
    // Normalized query term as it appears in the postings map.
    term: String,
    // Postings for this term: document id -> raw term frequency
    // (fed to `bm25_component` as the tf value).
    postings: &'a HashMap<String, i64>,
    // Domain weight multiplier applied to each BM25 component.
    weight: f64,
    // Which indexed domain (original text, pinyin, ...) these postings come from.
    domain: TermDomain,
}
24
25impl InMemoryIndex {
26 pub fn search(&self, index_name: &str, query: &str) -> Vec<(String, f64)> {
28 self.search_with_mode_hits(index_name, query, SearchMode::Auto)
29 .into_iter()
30 .map(|hit| (hit.doc_id, hit.score))
31 .collect()
32 }
33
34 pub fn search_hits(&self, index_name: &str, query: &str) -> Vec<SearchHit> {
36 self.search_with_mode_hits(index_name, query, SearchMode::Auto)
37 }
38
39 pub fn search_with_mode(
41 &self,
42 index_name: &str,
43 query: &str,
44 mode: SearchMode,
45 ) -> Vec<(String, f64)> {
46 self.search_with_mode_hits(index_name, query, mode)
47 .into_iter()
48 .map(|hit| (hit.doc_id, hit.score))
49 .collect()
50 }
51
52 pub fn search_with_mode_hits(
54 &self,
55 index_name: &str,
56 query: &str,
57 mode: SearchMode,
58 ) -> Vec<SearchHit> {
59 if query == "*" || query.is_empty() {
60 if let Some(docs) = self.docs.get(index_name) {
61 return docs
62 .keys()
63 .map(|k| SearchHit {
64 doc_id: k.clone(),
65 score: 1.0,
66 matched_terms: Vec::new(),
67 })
68 .collect();
69 }
70 return vec![];
71 }
72
73 let query_terms = self.tokenize_query(query);
74 if query_terms.is_empty() {
75 return vec![];
76 }
77
78 match mode {
79 SearchMode::Exact => self.bm25_search(index_name, &query_terms, TermDomain::Original),
80 SearchMode::Pinyin => self.pinyin_search(index_name, &query_terms),
81 SearchMode::Fuzzy => self.fuzzy_search(index_name, &query_terms),
82 SearchMode::Auto => {
83 let exact = self.bm25_search(index_name, &query_terms, TermDomain::Original);
84 if has_minimum_should_match(&exact, query_terms.len()) {
85 return exact;
88 }
89
90 if !is_ascii_alphanumeric_query(&query_terms) {
91 return self.fuzzy_search_internal(index_name, &query_terms, true);
92 }
93
94 let pinyin_prefix = self.pinyin_prefix_search(index_name, &query_terms);
95 if has_minimum_should_match(&pinyin_prefix, query_terms.len()) {
96 return pinyin_prefix;
97 }
98
99 let pinyin_exact = self.pinyin_exact_search(index_name, &query_terms);
100 if has_minimum_should_match(&pinyin_exact, query_terms.len()) {
101 return pinyin_exact;
102 }
103
104 if is_ascii_alphanumeric_query(&query_terms) {
105 let fuzzy_original = self.fuzzy_search(index_name, &query_terms);
106 if !fuzzy_original.is_empty() {
107 return fuzzy_original;
108 }
109 } else {
110 let cjk_fuzzy = self.fuzzy_search_internal(index_name, &query_terms, true);
111 if !cjk_fuzzy.is_empty() {
112 return cjk_fuzzy;
113 }
114 }
115
116 self.fuzzy_pinyin_search(index_name, &query_terms)
117 }
118 }
119 }
120
121 fn bm25_search(
122 &self,
123 index_name: &str,
124 query_terms: &[Token],
125 domain: TermDomain,
126 ) -> Vec<SearchHit> {
127 if query_terms.is_empty() {
128 return vec![];
129 }
130
131 let domains = match self.domains.get(index_name) {
132 Some(d) => d,
133 None => return vec![],
134 };
135
136 let domain_index = match domains.get(&domain) {
137 Some(idx) => idx,
138 None => return vec![],
139 };
140
141 let docs = match self.docs.get(index_name) {
142 Some(d) => d,
143 None => return vec![],
144 };
145
146 let mut term_views: Vec<TermView<'_>> = Vec::new();
147 let weight = domain_config(domain).weight;
148
149 for token in query_terms {
150 let Some(doc_map) = domain_index.postings.get(&token.term) else {
151 continue;
152 };
153
154 if doc_map.is_empty() {
155 continue;
156 }
157
158 term_views.push(TermView {
159 term: token.term.clone(),
160 postings: doc_map,
161 weight,
162 domain,
163 });
164 }
165
166 if term_views.is_empty() {
167 return vec![];
168 }
169
170 let min_should_match =
171 compute_min_should_match(query_terms.len(), term_views.len(), MIN_SHOULD_MATCH_RATIO);
172
173 let n = docs.len() as f64;
174 if n <= 0.0 {
175 return vec![];
176 }
177 let avgdl = average_doc_len(self, index_name, domain, docs.len());
178
179 let mut idfs = HashMap::new();
180 for view in &term_views {
181 let n_q = view.postings.len() as f64;
182 let idf = ((n - n_q + 0.5) / (n_q + 0.5) + 1.0).ln();
183 idfs.insert(view.term.clone(), idf);
184 }
185
186 let mut matches: HashMap<String, HashSet<MatchedTerm>> = HashMap::new();
187 let mut doc_scores: HashMap<String, f64> = HashMap::new();
188 for view in &term_views {
189 for (doc_id, freq) in view.postings {
190 let Some(doc_data) = docs.get(doc_id) else {
191 continue;
192 };
193 let idf = *idfs.get(&view.term).unwrap_or(&0.0);
194 let component = bm25_component(
195 *freq as f64,
196 doc_len_for_domain(doc_data, view.domain),
197 avgdl,
198 idf,
199 ) * view.weight;
200 if component > 0.0 {
201 *doc_scores.entry(doc_id.clone()).or_default() += component;
202 matches
203 .entry(doc_id.clone())
204 .or_default()
205 .insert(MatchedTerm::new(view.term.clone(), view.domain));
206 }
207 }
208 }
209
210 let mut scores: Vec<(String, f64)> = doc_scores
211 .into_iter()
212 .filter(|(doc_id, _)| {
213 matches
214 .get(doc_id)
215 .map(|set| set.len() >= min_should_match)
216 .unwrap_or(false)
217 })
218 .collect();
219 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
220 scores
221 .into_iter()
222 .map(|(doc_id, score)| SearchHit {
223 doc_id: doc_id.clone(),
224 score,
225 matched_terms: matches
226 .remove(&doc_id)
227 .map(|s| s.into_iter().collect())
228 .unwrap_or_default(),
229 })
230 .collect()
231 }
232
233 fn pinyin_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
234 if !is_ascii_alphanumeric_query(query_terms) {
235 return vec![];
236 }
237
238 let exact = self.pinyin_exact_search(index_name, query_terms);
239 if !exact.is_empty() {
240 return exact;
241 }
242
243 self.pinyin_prefix_search(index_name, query_terms)
244 }
245
246 fn pinyin_prefix_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
247 let full_prefix = self.bm25_search(index_name, query_terms, TermDomain::PinyinFullPrefix);
248 if !full_prefix.is_empty() {
249 return full_prefix;
250 }
251
252 self.bm25_search(index_name, query_terms, TermDomain::PinyinInitialsPrefix)
253 }
254
255 fn pinyin_exact_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
256 let full = self.bm25_search(index_name, query_terms, TermDomain::PinyinFull);
257 if !full.is_empty() {
258 return full;
259 }
260
261 self.bm25_search(index_name, query_terms, TermDomain::PinyinInitials)
262 }
263
264 fn fuzzy_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
265 self.fuzzy_search_internal(index_name, query_terms, false)
266 }
267
268 fn fuzzy_search_internal(
269 &self,
270 index_name: &str,
271 query_terms: &[Token],
272 allow_non_ascii: bool,
273 ) -> Vec<SearchHit> {
274 self.fuzzy_search_in_domain(
275 index_name,
276 query_terms,
277 TermDomain::Original,
278 allow_non_ascii,
279 )
280 }
281
282 fn fuzzy_pinyin_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
283 if query_terms.is_empty() || !is_ascii_alphanumeric_query(query_terms) {
284 return vec![];
285 }
286
287 let full =
288 self.fuzzy_search_in_domain(index_name, query_terms, TermDomain::PinyinFull, false);
289 if !full.is_empty() {
290 return full;
291 }
292
293 self.fuzzy_search_in_domain(index_name, query_terms, TermDomain::PinyinInitials, false)
294 }
295
296 fn fuzzy_search_in_domain(
297 &self,
298 index_name: &str,
299 query_terms: &[Token],
300 domain: TermDomain,
301 allow_non_ascii: bool,
302 ) -> Vec<SearchHit> {
303 if query_terms.is_empty() || (!allow_non_ascii && !is_ascii_alphanumeric_query(query_terms))
304 {
305 return vec![];
306 }
307
308 if !domain_config(domain).allow_fuzzy {
309 return vec![];
310 }
311
312 let docs = match self.docs.get(index_name) {
313 Some(d) => d,
314 None => return vec![],
315 };
316
317 let domains = match self.domains.get(index_name) {
318 Some(d) => d,
319 None => return vec![],
320 };
321 let domain_index = match domains.get(&domain) {
322 Some(idx) => idx,
323 None => return vec![],
324 };
325
326 let n = docs.len() as f64;
327 if n <= 0.0 {
328 return vec![];
329 }
330 let avgdl = average_doc_len(self, index_name, domain, docs.len());
331
332 let mut doc_scores: HashMap<String, f64> = HashMap::new();
333 let mut matched_terms: HashMap<String, HashSet<MatchedTerm>> = HashMap::new();
334 let weight = domain_config(domain).weight;
335 let mut matched_query_tokens: HashMap<String, HashSet<usize>> = HashMap::new();
336 let mut tokens_with_candidates: HashSet<usize> = HashSet::new();
337
338 for (idx, token) in query_terms.iter().enumerate() {
339 score_fuzzy_terms(
340 docs,
341 domain_index,
342 n,
343 avgdl,
344 &mut doc_scores,
345 &mut matched_terms,
346 &mut matched_query_tokens,
347 &mut tokens_with_candidates,
348 domain,
349 weight,
350 &token.term,
351 &|doc_data| doc_len_for_domain(doc_data, domain),
352 idx,
353 );
354 }
355
356 let available_terms = tokens_with_candidates.len();
357 let min_should_match =
358 compute_min_should_match(query_terms.len(), available_terms, MIN_SHOULD_MATCH_RATIO);
361
362 let mut scores: Vec<(String, f64)> = doc_scores
363 .into_iter()
364 .filter(|(doc_id, _)| {
365 matched_query_tokens
366 .get(doc_id)
367 .map(|set| set.len() >= min_should_match)
368 .unwrap_or(false)
369 })
370 .collect();
371 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
372 scores
373 .into_iter()
374 .map(|(doc_id, score)| SearchHit {
375 matched_terms: matched_terms
376 .remove(&doc_id)
377 .map(|s| s.into_iter().collect())
378 .unwrap_or_default(),
379 doc_id,
380 score,
381 })
382 .collect()
383 }
384}
385
386pub(super) fn is_ascii_alphanumeric_query(tokens: &[Token]) -> bool {
387 tokens
388 .iter()
389 .all(|token| token.term.chars().all(|c| c.is_ascii_alphanumeric()))
390}
391
392fn doc_len_for_domain(doc_data: &DocData, domain: TermDomain) -> f64 {
393 if domain.is_prefix() {
394 return 0.0;
397 }
398
399 let len = doc_data.domain_doc_len.get(domain);
400 if len > 0 {
401 len as f64
402 } else {
403 doc_data.doc_len as f64
404 }
405}
406
407fn average_doc_len(
408 index: &InMemoryIndex,
409 index_name: &str,
410 domain: TermDomain,
411 doc_count: usize,
412) -> f64 {
413 if domain.is_prefix() || doc_count == 0 {
414 return 0.0;
415 }
416
417 let total = index
418 .domain_total_lens
419 .get(index_name)
420 .map(|m| m.get(domain))
421 .unwrap_or(0);
422 if total <= 0 {
423 0.0
424 } else {
425 total as f64 / doc_count as f64
426 }
427}