1use std::{
2 cmp::Ordering,
3 collections::{HashMap, HashSet},
4};
5
6use smol_str::SmolStr;
7
8use super::{
9 super::{
10 ngram::{build_ngram_index, should_index_in_original_aux},
11 tokenizer::Token,
12 types::{
13 DocData, DocId, InMemoryIndex, IndexState, Posting, SearchMode, TermDomain, TermId,
14 domain_config,
15 },
16 },
17 MatchedTerm, SearchHit,
18 scoring::{
19 MIN_SHOULD_MATCH_RATIO, bm25_component, compute_min_should_match, has_minimum_should_match,
20 score_fuzzy_terms,
21 },
22};
23
/// Minimum query-token length required to attempt a prefix match against
/// full-pinyin terms (very short prefixes would match too broadly).
const PINYIN_FULL_PREFIX_MIN: usize = 2;
/// Minimum query-token length for prefix matching against pinyin-initials
/// terms; initials are single characters, so length 1 is allowed.
const PINYIN_INITIALS_PREFIX_MIN: usize = 1;
/// Upper bound on the prefix lengths stored in (and looked up from) the
/// lazily built prefix index.
const PINYIN_PREFIX_MAX: usize = 16;
27
/// A query term resolved against one term domain: the postings to score plus
/// the metadata needed to attribute matches back to the query.
struct TermView<'a> {
    term_id: TermId,         // id of the term in the index's term table
    term_text: String,       // display text reported via MatchedTerm
    postings: &'a [Posting], // postings list for this term in the domain
    weight: f64,             // domain weight applied to BM25 components
    domain: TermDomain,      // domain the postings came from
}
35
36impl InMemoryIndex {
37 pub fn search(&self, index_name: &str, query: &str) -> Vec<(String, f64)> {
39 self.search_with_mode_hits(index_name, query, SearchMode::Auto)
40 .into_iter()
41 .map(|hit| (hit.doc_id, hit.score))
42 .collect()
43 }
44
    /// Search in [`SearchMode::Auto`], returning full hits (score plus the
    /// matched-term details) rather than bare `(doc_id, score)` pairs.
    pub fn search_hits(&self, index_name: &str, query: &str) -> Vec<SearchHit> {
        self.search_with_mode_hits(index_name, query, SearchMode::Auto)
    }
49
50 pub fn search_with_mode(
52 &self,
53 index_name: &str,
54 query: &str,
55 mode: SearchMode,
56 ) -> Vec<(String, f64)> {
57 self.search_with_mode_hits(index_name, query, mode)
58 .into_iter()
59 .map(|hit| (hit.doc_id, hit.score))
60 .collect()
61 }
62
63 pub fn search_with_mode_hits(
65 &self,
66 index_name: &str,
67 query: &str,
68 mode: SearchMode,
69 ) -> Vec<SearchHit> {
70 if query == "*" || query.is_empty() {
71 if let Some(state) = self.indexes.get(index_name) {
72 return state
73 .doc_index
74 .keys()
75 .map(|doc_id| SearchHit {
76 doc_id: doc_id.to_string(),
77 score: 1.0,
78 matched_terms: Vec::new(),
79 })
80 .collect();
81 }
82 return vec![];
83 }
84
85 let query_terms = self.tokenize_query(query);
86 if query_terms.is_empty() {
87 return vec![];
88 }
89
90 match mode {
91 SearchMode::Exact => self.bm25_search(index_name, &query_terms, TermDomain::Original),
92 SearchMode::Pinyin => self.pinyin_search(index_name, &query_terms),
93 SearchMode::Fuzzy => self.fuzzy_search(index_name, &query_terms),
94 SearchMode::Auto => {
95 let exact = self.bm25_search(index_name, &query_terms, TermDomain::Original);
96 if has_minimum_should_match(&exact, query_terms.len()) {
97 return exact;
100 }
101
102 if !is_ascii_alphanumeric_query(&query_terms) {
103 return self.fuzzy_search_internal(index_name, &query_terms, true);
104 }
105
106 let pinyin_prefix = self.pinyin_prefix_search(index_name, &query_terms);
107 if has_minimum_should_match(&pinyin_prefix, query_terms.len()) {
108 return pinyin_prefix;
109 }
110
111 let pinyin_exact = self.pinyin_exact_search(index_name, &query_terms);
112 if has_minimum_should_match(&pinyin_exact, query_terms.len()) {
113 return pinyin_exact;
114 }
115
116 if is_ascii_alphanumeric_query(&query_terms) {
117 let fuzzy_original = self.fuzzy_search(index_name, &query_terms);
118 if !fuzzy_original.is_empty() {
119 return fuzzy_original;
120 }
121 } else {
122 let cjk_fuzzy = self.fuzzy_search_internal(index_name, &query_terms, true);
123 if !cjk_fuzzy.is_empty() {
124 return cjk_fuzzy;
125 }
126 }
127
128 self.fuzzy_pinyin_search(index_name, &query_terms)
129 }
130 }
131 }
132
    /// BM25 ranking of `query_terms` within a single term domain.
    ///
    /// Resolves each query token to a term id and its postings list, sums a
    /// weighted BM25 component per document, and keeps only documents that
    /// matched at least `min_should_match` distinct query terms. Results are
    /// sorted by descending score.
    fn bm25_search(
        &self,
        index_name: &str,
        query_terms: &[Token],
        domain: TermDomain,
    ) -> Vec<SearchHit> {
        if query_terms.is_empty() {
            return vec![];
        }

        let state = match self.indexes.get(index_name) {
            Some(state) => state,
            None => return vec![],
        };

        let domain_index = match state.domains.get(&domain) {
            Some(idx) => idx,
            None => return vec![],
        };

        let doc_count = state.doc_index.len();
        if doc_count == 0 {
            return vec![];
        }

        // Resolve each query token to a TermView; tokens unknown to the
        // index or with empty postings are silently skipped.
        let mut term_views: Vec<TermView<'_>> = Vec::new();
        let weight = domain_config(domain).weight;

        for token in query_terms {
            let Some(&term_id) = state.term_index.get(token.term.as_str()) else {
                continue;
            };
            let Some(postings) = domain_index.postings.get(&term_id) else {
                continue;
            };
            if postings.is_empty() {
                continue;
            }
            // Prefer the canonical term text from the term table; fall back
            // to the raw token text if the id is somehow out of range.
            let term_text = state
                .terms
                .get(term_id as usize)
                .map(|term| term.as_str().to_string())
                .unwrap_or_else(|| token.term.clone());
            term_views.push(TermView {
                term_id,
                term_text,
                postings,
                weight,
                domain,
            });
        }

        if term_views.is_empty() {
            return vec![];
        }

        // How many distinct query terms a document must match to be kept,
        // based on how many query terms actually resolved in this domain.
        let min_should_match =
            compute_min_should_match(query_terms.len(), term_views.len(), MIN_SHOULD_MATCH_RATIO);

        let n = doc_count as f64;
        let avgdl = average_doc_len(state, domain, doc_count);

        // Per-term IDF: ln((N - n_q + 0.5) / (n_q + 0.5) + 1), where n_q is
        // the term's document frequency (postings length).
        let mut idfs = HashMap::new();
        for view in &term_views {
            let n_q = view.postings.len() as f64;
            let idf = ((n - n_q + 0.5) / (n_q + 0.5) + 1.0).ln();
            idfs.insert(view.term_id, idf);
        }

        // Accumulate weighted BM25 components per document, remembering
        // which terms matched each document for min-should-match filtering
        // and for reporting in the hits.
        let mut matches: HashMap<DocId, HashSet<MatchedTerm>> = HashMap::new();
        let mut doc_scores: HashMap<DocId, f64> = HashMap::new();
        for view in &term_views {
            let idf = *idfs.get(&view.term_id).unwrap_or(&0.0);
            for posting in view.postings {
                // Skip postings that point at deleted/missing doc slots.
                let Some(doc_data) = state
                    .docs
                    .get(posting.doc as usize)
                    .and_then(|doc| doc.as_ref())
                else {
                    continue;
                };
                let component = bm25_component(
                    posting.freq as f64,
                    doc_len_for_domain(doc_data, view.domain),
                    avgdl,
                    idf,
                ) * view.weight;
                if component > 0.0 {
                    *doc_scores.entry(posting.doc).or_default() += component;
                    matches
                        .entry(posting.doc)
                        .or_default()
                        .insert(MatchedTerm::new(view.term_text.clone(), view.domain));
                }
            }
        }

        // Enforce min-should-match, then sort by descending score.
        let mut scores: Vec<(DocId, f64)> = doc_scores
            .into_iter()
            .filter(|(doc_id, _)| {
                matches
                    .get(doc_id)
                    .map(|set| set.len() >= min_should_match)
                    .unwrap_or(false)
            })
            .collect();
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        scores
            .into_iter()
            .filter_map(|(doc_id, score)| {
                // Map the internal DocId back to the external doc name;
                // drop hits whose id no longer resolves.
                let doc_name = state.doc_ids.get(doc_id as usize)?.to_string();
                Some(SearchHit {
                    doc_id: doc_name,
                    score,
                    matched_terms: matches
                        .remove(&doc_id)
                        .map(|s| s.into_iter().collect())
                        .unwrap_or_default(),
                })
            })
            .collect()
    }
255
256 fn pinyin_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
257 if !is_ascii_alphanumeric_query(query_terms) {
258 return vec![];
259 }
260
261 let exact = self.pinyin_exact_search(index_name, query_terms);
262 if !exact.is_empty() {
263 return exact;
264 }
265
266 self.pinyin_prefix_search(index_name, query_terms)
267 }
268
269 fn pinyin_prefix_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
270 let full_prefix = self.prefix_search_in_domain(
271 index_name,
272 query_terms,
273 TermDomain::PinyinFull,
274 PINYIN_FULL_PREFIX_MIN,
275 );
276 if !full_prefix.is_empty() {
277 return full_prefix;
278 }
279
280 self.prefix_search_in_domain(
281 index_name,
282 query_terms,
283 TermDomain::PinyinInitials,
284 PINYIN_INITIALS_PREFIX_MIN,
285 )
286 }
287
288 fn pinyin_exact_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
289 let full = self.bm25_search(index_name, query_terms, TermDomain::PinyinFull);
290 if !full.is_empty() {
291 return full;
292 }
293
294 self.bm25_search(index_name, query_terms, TermDomain::PinyinInitials)
295 }
296
    /// Fuzzy search over the original-term domain with non-ASCII queries
    /// rejected (`allow_non_ascii = false`).
    fn fuzzy_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
        self.fuzzy_search_internal(index_name, query_terms, false)
    }
300
301 fn fuzzy_search_internal(
302 &self,
303 index_name: &str,
304 query_terms: &[Token],
305 allow_non_ascii: bool,
306 ) -> Vec<SearchHit> {
307 self.fuzzy_search_in_domain(
308 index_name,
309 query_terms,
310 TermDomain::Original,
311 allow_non_ascii,
312 )
313 }
314
315 fn fuzzy_pinyin_search(&self, index_name: &str, query_terms: &[Token]) -> Vec<SearchHit> {
316 if query_terms.is_empty() || !is_ascii_alphanumeric_query(query_terms) {
317 return vec![];
318 }
319
320 let full =
321 self.fuzzy_search_in_domain(index_name, query_terms, TermDomain::PinyinFull, false);
322 if !full.is_empty() {
323 return full;
324 }
325
326 self.fuzzy_search_in_domain(index_name, query_terms, TermDomain::PinyinInitials, false)
327 }
328
    /// Fuzzy (n-gram based) matching of `query_terms` against one domain.
    ///
    /// Lazily builds the domain's auxiliary structures (sorted candidate
    /// term-id list and n-gram index) on first use, then scores fuzzy
    /// candidates per query token via `score_fuzzy_terms`. Documents must
    /// match at least `min_should_match` distinct query tokens to be
    /// returned; hits are sorted by descending score.
    fn fuzzy_search_in_domain(
        &self,
        index_name: &str,
        query_terms: &[Token],
        domain: TermDomain,
        allow_non_ascii: bool,
    ) -> Vec<SearchHit> {
        // Fuzzy matching is restricted to ASCII queries unless the caller
        // explicitly opts non-ASCII (e.g. CJK) queries in.
        if query_terms.is_empty() || (!allow_non_ascii && !is_ascii_alphanumeric_query(query_terms))
        {
            return vec![];
        }

        // Some domains disable fuzzy matching entirely via configuration.
        if !domain_config(domain).allow_fuzzy {
            return vec![];
        }

        let state = match self.indexes.get(index_name) {
            Some(state) => state,
            None => return vec![],
        };

        let domain_index = match state.domains.get(&domain) {
            Some(idx) => idx,
            None => return vec![],
        };

        let doc_count = state.doc_index.len();
        if doc_count == 0 {
            return vec![];
        }

        // Lazily initialize the aux structures under the write lock. For the
        // Original domain, only terms passing should_index_in_original_aux
        // are eligible as fuzzy candidates; other domains index every term.
        {
            let mut aux = domain_index.aux.write().unwrap();
            if aux.term_ids.is_none() {
                let mut ids: Vec<TermId> = domain_index
                    .postings
                    .keys()
                    .copied()
                    .filter(|term_id| {
                        if domain == TermDomain::Original {
                            state
                                .terms
                                .get(*term_id as usize)
                                .map(|term| should_index_in_original_aux(term.as_str()))
                                .unwrap_or(false)
                        } else {
                            true
                        }
                    })
                    .collect();
                ids.sort_unstable();
                aux.term_ids = Some(ids);
            }
            if aux.ngram_index.is_none() {
                let ids = aux.term_ids.as_ref().unwrap();
                aux.ngram_index = Some(build_ngram_index(ids, &state.terms));
            }
        }
        // Downgrade to a read lock for the query phase. The Options were
        // populated above and are never cleared back to None, so these
        // unwraps hold even if another thread raced the initialization.
        let aux = domain_index.aux.read().unwrap();
        let term_ids = aux.term_ids.as_ref().unwrap();
        let ngram_index = aux.ngram_index.as_ref().unwrap();

        let n = doc_count as f64;
        let avgdl = average_doc_len(state, domain, doc_count);

        // Accumulators filled in by score_fuzzy_terms:
        // - doc_scores: summed weighted score per document
        // - matched_terms: terms matched per document (for reporting)
        // - matched_query_tokens: query-token indices matched per document
        // - tokens_with_candidates: query-token indices with any candidate
        let mut doc_scores: HashMap<DocId, f64> = HashMap::new();
        let mut matched_terms: HashMap<DocId, HashSet<MatchedTerm>> = HashMap::new();
        let weight = domain_config(domain).weight;
        let mut matched_query_tokens: HashMap<DocId, HashSet<usize>> = HashMap::new();
        let mut tokens_with_candidates: HashSet<usize> = HashSet::new();

        for (idx, token) in query_terms.iter().enumerate() {
            // An exact term id (if the token exists verbatim) lets the scorer
            // treat exact hits specially.
            let exact_term = state.term_index.get(token.term.as_str()).copied();
            score_fuzzy_terms(
                &state.docs,
                domain_index,
                term_ids,
                &state.terms,
                ngram_index,
                n,
                avgdl,
                &mut doc_scores,
                &mut matched_terms,
                &mut matched_query_tokens,
                &mut tokens_with_candidates,
                domain,
                weight,
                &token.term,
                idx,
                exact_term,
            );
        }

        // min-should-match is based on how many query tokens had any fuzzy
        // candidate at all, not the raw query length.
        let available_terms = tokens_with_candidates.len();
        let min_should_match =
            compute_min_should_match(query_terms.len(), available_terms, MIN_SHOULD_MATCH_RATIO);

        let mut scores: Vec<(DocId, f64)> = doc_scores
            .into_iter()
            .filter(|(doc_id, _)| {
                matched_query_tokens
                    .get(doc_id)
                    .map(|set| set.len() >= min_should_match)
                    .unwrap_or(false)
            })
            .collect();
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        scores
            .into_iter()
            .filter_map(|(doc_id, score)| {
                // Map the internal DocId back to the external doc name; drop
                // hits whose id no longer resolves.
                let doc_name = state.doc_ids.get(doc_id as usize)?.to_string();
                Some(SearchHit {
                    matched_terms: matched_terms
                        .remove(&doc_id)
                        .map(|s| s.into_iter().collect())
                        .unwrap_or_default(),
                    doc_id: doc_name,
                    score,
                })
            })
            .collect()
    }
453
    /// Prefix matching of `query_terms` against the terms of one domain.
    ///
    /// Lazily builds (and caches) a prefix -> term-ids map covering prefixes
    /// of length `min_prefix_len..=PINYIN_PREFIX_MAX` for every ASCII term in
    /// the domain, then scores every term whose prefix equals a query token
    /// with a weighted BM25 component. Documents must match at least
    /// `min_should_match` query tokens to be returned.
    ///
    /// NOTE(review): the cached prefix index bakes in the `min_prefix_len` of
    /// the first call for this domain. Callers currently use a fixed value
    /// per domain, so the cache stays consistent — confirm if that changes.
    fn prefix_search_in_domain(
        &self,
        index_name: &str,
        query_terms: &[Token],
        domain: TermDomain,
        min_prefix_len: usize,
    ) -> Vec<SearchHit> {
        // Prefix search only applies to ASCII (pinyin-like) queries.
        if query_terms.is_empty() || !is_ascii_alphanumeric_query(query_terms) {
            return vec![];
        }

        let state = match self.indexes.get(index_name) {
            Some(state) => state,
            None => return vec![],
        };

        let domain_index = match state.domains.get(&domain) {
            Some(idx) => idx,
            None => return vec![],
        };

        let doc_count = state.doc_index.len();
        if doc_count == 0 {
            return vec![];
        }

        // Lazily initialize the term-id list and the prefix index under the
        // write lock.
        {
            let mut aux = domain_index.aux.write().unwrap();
            if aux.term_ids.is_none() {
                let mut ids: Vec<TermId> = domain_index.postings.keys().copied().collect();
                ids.sort_unstable();
                aux.term_ids = Some(ids);
            }
            if aux.prefix_index.is_none() {
                let mut prefix_index: HashMap<SmolStr, Vec<TermId>> = HashMap::new();
                let ids = aux.term_ids.as_ref().unwrap();
                for &term_id in ids {
                    let Some(term) = state.terms.get(term_id as usize) else {
                        continue;
                    };
                    // Only ASCII terms are indexed, so the byte-length checks
                    // and byte slicing below are char-safe.
                    if !term.as_str().is_ascii() {
                        continue;
                    }
                    let term_len = term.len();
                    if term_len < min_prefix_len {
                        continue;
                    }
                    // Register every prefix of the term from min_prefix_len
                    // up to PINYIN_PREFIX_MAX (capped at the term length).
                    let max = PINYIN_PREFIX_MAX.min(term_len);
                    for len in min_prefix_len..=max {
                        let prefix = SmolStr::new(&term.as_str()[..len]);
                        prefix_index.entry(prefix).or_default().push(term_id);
                    }
                }
                aux.prefix_index = Some(prefix_index);
            }
        }
        // Downgrade to a read lock; prefix_index was populated above and is
        // never reset to None, so the unwrap holds.
        let aux = domain_index.aux.read().unwrap();
        let prefix_index = aux.prefix_index.as_ref().unwrap();

        let n = doc_count as f64;
        let avgdl = average_doc_len(state, domain, doc_count);

        // Accumulators: score per document, matched terms and matched
        // query-token indices per document, and which query tokens had any
        // candidate at all.
        let mut doc_scores: HashMap<DocId, f64> = HashMap::new();
        let mut matched_terms: HashMap<DocId, HashSet<MatchedTerm>> = HashMap::new();
        let weight = domain_config(domain).weight;
        let mut matched_query_tokens: HashMap<DocId, HashSet<usize>> = HashMap::new();
        let mut tokens_with_candidates: HashSet<usize> = HashSet::new();

        for (idx, token) in query_terms.iter().enumerate() {
            // Tokens outside the indexed prefix-length window cannot match.
            if token.term.len() < min_prefix_len || token.term.len() > PINYIN_PREFIX_MAX {
                continue;
            }

            let Some(candidates) = prefix_index.get(token.term.as_str()) else {
                continue;
            };
            if candidates.is_empty() {
                continue;
            }

            tokens_with_candidates.insert(idx);

            // Score every term that starts with this token.
            for &candidate in candidates {
                let Some(postings) = domain_index.postings.get(&candidate) else {
                    continue;
                };
                if postings.is_empty() {
                    continue;
                }

                // IDF of the candidate term (BM25 formulation).
                let n_q = postings.len() as f64;
                let idf = ((n - n_q + 0.5) / (n_q + 0.5) + 1.0).ln();
                let candidate_text = state
                    .terms
                    .get(candidate as usize)
                    .map(|term| term.as_str().to_string())
                    .unwrap_or_else(|| token.term.clone());

                for posting in postings {
                    // Skip postings that point at deleted/missing doc slots.
                    let Some(doc_data) = state
                        .docs
                        .get(posting.doc as usize)
                        .and_then(|doc| doc.as_ref())
                    else {
                        continue;
                    };
                    let term_score = bm25_component(
                        posting.freq as f64,
                        doc_len_for_domain(doc_data, domain),
                        avgdl,
                        idf,
                    ) * weight;
                    if term_score > 0.0 {
                        *doc_scores.entry(posting.doc).or_default() += term_score;
                        matched_terms
                            .entry(posting.doc)
                            .or_default()
                            .insert(MatchedTerm::new(candidate_text.clone(), domain));
                        matched_query_tokens
                            .entry(posting.doc)
                            .or_default()
                            .insert(idx);
                    }
                }
            }
        }

        // min-should-match is based on how many query tokens had candidates.
        let available_terms = tokens_with_candidates.len();
        let min_should_match =
            compute_min_should_match(query_terms.len(), available_terms, MIN_SHOULD_MATCH_RATIO);

        let mut scores: Vec<(DocId, f64)> = doc_scores
            .into_iter()
            .filter(|(doc_id, _)| {
                matched_query_tokens
                    .get(doc_id)
                    .map(|set| set.len() >= min_should_match)
                    .unwrap_or(false)
            })
            .collect();
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        scores
            .into_iter()
            .filter_map(|(doc_id, score)| {
                // Map the internal DocId back to the external doc name; drop
                // hits whose id no longer resolves.
                let doc_name = state.doc_ids.get(doc_id as usize)?.to_string();
                Some(SearchHit {
                    matched_terms: matched_terms
                        .remove(&doc_id)
                        .map(|s| s.into_iter().collect())
                        .unwrap_or_default(),
                    doc_id: doc_name,
                    score,
                })
            })
            .collect()
    }
610}
611
612pub(super) fn is_ascii_alphanumeric_query(tokens: &[Token]) -> bool {
613 tokens
614 .iter()
615 .all(|token| token.term.chars().all(|c| c.is_ascii_alphanumeric()))
616}
617
618fn doc_len_for_domain(doc_data: &DocData, domain: TermDomain) -> f64 {
619 let len = doc_data.domain_doc_len.get(domain);
620 if len > 0 {
621 len as f64
622 } else {
623 doc_data.doc_len as f64
624 }
625}
626
627fn average_doc_len(state: &IndexState, domain: TermDomain, doc_count: usize) -> f64 {
628 if doc_count == 0 {
629 return 0.0;
630 }
631
632 let total = state.domain_total_len.get(domain);
633 if total <= 0 {
634 0.0
635 } else {
636 total as f64 / doc_count as f64
637 }
638}