1use std::collections::{HashMap, HashSet};
2
3use serde::{Deserialize, Serialize};
4
5use super::tokenizer::{DictionaryConfig, OffsetMap, SegmentScript, TokenWithScript};
6
/// Version number of the snapshot format.
///
/// NOTE(review): presumably written to / checked against [`SnapshotData::version`];
/// the comparison itself is not in this file — confirm at the save/load site.
pub const SNAPSHOT_VERSION: u32 = 4;
8
/// Selects how a search is carried out.
///
/// NOTE(review): the behaviour behind each mode lives outside this file; the
/// per-variant notes below are inferred from the names only — confirm against
/// the search implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
    /// Presumably exact term matching only.
    Exact,
    /// Presumably matching through the pinyin domains of [`TermDomain`].
    Pinyin,
    /// Presumably approximate (fuzzy) matching.
    Fuzzy,
    /// Presumably lets the engine choose among the other modes.
    Auto,
}
21
/// Identifies the index domain a term belongs to.
///
/// Each domain has its own [`DomainIndex`] (see [`InMemoryIndex::domains`]), its
/// own settings via [`domain_config`], and a fixed slot in [`DomainLengths`].
///
/// The variant order matters: `domain_index` and [`all_domains`] list the
/// variants in exactly this order, and `TERM_DOMAIN_COUNT` must equal the
/// number of variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TermDomain {
    /// The only non-pinyin domain (presumably terms as they appear in the text).
    Original,
    /// Pinyin domain; presumably the full pinyin spelling of a term.
    PinyinFull,
    /// Pinyin domain; presumably the initial letters of each pinyin syllable.
    PinyinInitials,
    /// Prefix domain paired with [`TermDomain::PinyinFull`]; n-grams and fuzzy
    /// matching are disabled for it in [`domain_config`].
    PinyinFullPrefix,
    /// Prefix domain paired with [`TermDomain::PinyinInitials`]; n-grams and
    /// fuzzy matching are disabled for it in [`domain_config`].
    PinyinInitialsPrefix,
}
31
// Number of `TermDomain` variants. Sizes the per-domain array inside
// `DomainLengths`, so it must stay in sync with the enum and with `domain_index`.
const TERM_DOMAIN_COUNT: usize = 5;
33
34const fn domain_index(domain: TermDomain) -> usize {
35 match domain {
36 TermDomain::Original => 0,
37 TermDomain::PinyinFull => 1,
38 TermDomain::PinyinInitials => 2,
39 TermDomain::PinyinFullPrefix => 3,
40 TermDomain::PinyinInitialsPrefix => 4,
41 }
42}
43
44impl TermDomain {
45 pub fn is_pinyin(&self) -> bool {
47 matches!(
48 self,
49 TermDomain::PinyinFull
50 | TermDomain::PinyinInitials
51 | TermDomain::PinyinFullPrefix
52 | TermDomain::PinyinInitialsPrefix
53 )
54 }
55
56 pub fn is_prefix(&self) -> bool {
58 matches!(
59 self,
60 TermDomain::PinyinFullPrefix | TermDomain::PinyinInitialsPrefix
61 )
62 }
63}
64
/// Per-domain settings, as returned by [`domain_config`].
///
/// NOTE(review): how these values are consumed is not visible in this file; the
/// field notes are inferred from the names — confirm against the scoring and
/// indexing code.
#[derive(Debug, Clone, Copy)]
pub struct DomainConfig {
    /// Presumably a score multiplier for matches found in this domain.
    pub weight: f64,
    /// Presumably whether an n-gram index is maintained for this domain.
    pub enable_ngrams: bool,
    /// Presumably whether fuzzy matching may be attempted in this domain.
    pub allow_fuzzy: bool,
}
71
72pub fn domain_config(domain: TermDomain) -> DomainConfig {
73 match domain {
74 TermDomain::Original => DomainConfig {
75 weight: 1.0,
76 enable_ngrams: true,
77 allow_fuzzy: true,
78 },
79 TermDomain::PinyinFull => DomainConfig {
80 weight: 0.9,
81 enable_ngrams: true,
82 allow_fuzzy: true,
83 },
84 TermDomain::PinyinInitials => DomainConfig {
85 weight: 0.8,
86 enable_ngrams: true,
87 allow_fuzzy: true,
88 },
89 TermDomain::PinyinFullPrefix => DomainConfig {
90 weight: 0.7,
91 enable_ngrams: false,
92 allow_fuzzy: false,
93 },
94 TermDomain::PinyinInitialsPrefix => DomainConfig {
95 weight: 0.75,
96 enable_ngrams: false,
97 allow_fuzzy: false,
98 },
99 }
100}
101
102pub fn all_domains() -> &'static [TermDomain] {
103 &[
104 TermDomain::Original,
105 TermDomain::PinyinFull,
106 TermDomain::PinyinInitials,
107 TermDomain::PinyinFullPrefix,
108 TermDomain::PinyinInitialsPrefix,
109 ]
110}
111
/// Index data held for a single [`TermDomain`].
///
/// NOTE(review): nothing in this file reads or writes these maps, so the key
/// meanings below are inferred from names and types — confirm against the
/// indexing code.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DomainIndex {
    /// Presumably term -> (document id -> count for that term in the document).
    pub postings: HashMap<String, HashMap<String, i64>>,
    /// Presumably the set of all terms known to this domain.
    pub term_dict: HashSet<String>,
    /// Presumably n-gram -> terms containing that n-gram.
    pub ngram_index: HashMap<String, Vec<String>>,
}
118
/// Per-domain occurrence counters.
///
/// Stored per key in [`DocData::term_freqs`] and [`TokenStream::term_freqs`]
/// (presumably one `TermFrequency` per term — confirm with the producer).
/// `#[serde(default)]` lets a serialized value that omits `counts` deserialize
/// to an empty map.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct TermFrequency {
    /// Counter per domain; a missing entry is treated as zero by
    /// [`TermFrequency::get`].
    pub counts: HashMap<TermDomain, u32>,
}
124
125impl TermFrequency {
126 pub fn increment(&mut self, domain: TermDomain) {
127 *self.counts.entry(domain).or_default() += 1;
128 }
129
130 pub fn get(&self, domain: TermDomain) -> u32 {
131 *self.counts.get(&domain).unwrap_or(&0)
132 }
133
134 pub fn positive_domains(&self) -> Vec<(TermDomain, u32)> {
135 let mut domains = Vec::new();
136 for domain in all_domains() {
137 if let Some(count) = self.counts.get(domain) {
138 if *count > 0 {
139 domains.push((*domain, *count));
140 }
141 }
142 }
143 domains
144 }
145}
146
/// Stored data for one indexed document.
///
/// The three `#[serde(default)]` fields may be absent from serialized input, in
/// which case they deserialize to their empty/zero defaults.
///
/// NOTE(review): the meaning of map keys and of the `(u32, u32)` pairs is not
/// established in this file; the notes below are inferred — confirm with the
/// indexing code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocData {
    /// Presumably the document text.
    pub content: String,
    /// Document length. Used by [`DomainLengths::from_doc`] as the
    /// `TermDomain::Original` length when no per-domain data is available.
    pub doc_len: i64,
    /// Presumably term -> list of `(start, end)` positions in `content`.
    pub term_pos: HashMap<String, Vec<(u32, u32)>>,
    /// Per-key, per-domain counters; [`DomainLengths::from_doc`] sums them when
    /// `domain_doc_len` is all zero.
    #[serde(default)]
    pub term_freqs: HashMap<String, TermFrequency>,
    /// Per-domain lengths. An all-zero value (the serde default) is treated by
    /// [`DomainLengths::from_doc`] as "not recorded".
    #[serde(default)]
    pub domain_doc_len: DomainLengths,
    /// Presumably derived (e.g. pinyin) term -> list of `(start, end)` positions.
    #[serde(default)]
    pub derived_terms: HashMap<String, Vec<(u32, u32)>>,
}
161
/// In-memory search index state.
///
/// NOTE(review): every map is keyed by a `String` at the outer level —
/// presumably an index name, with inner `String` keys/sets holding document
/// ids. Neither is established in this file; confirm against the code that
/// populates these maps.
#[derive(Debug)]
pub struct InMemoryIndex {
    /// Outer key -> version number (presumably compared with `SNAPSHOT_VERSION`).
    pub versions: HashMap<String, u32>,
    /// Outer key -> (inner key -> stored document data).
    pub docs: HashMap<String, HashMap<String, DocData>>,
    /// Outer key -> per-domain index data.
    pub domains: HashMap<String, HashMap<TermDomain, DomainIndex>>,
    /// Outer key -> total length (presumably the sum of `DocData::doc_len`).
    pub total_lens: HashMap<String, i64>,
    /// Outer key -> per-domain total lengths.
    pub domain_total_lens: HashMap<String, DomainLengths>,
    /// Outer key -> set of entries presumably modified since the last persist.
    pub dirty: HashMap<String, HashSet<String>>,
    /// Outer key -> set of entries presumably removed since the last persist.
    pub deleted: HashMap<String, HashSet<String>>,
    /// Position encoding in use; the `Default` impl selects `PositionEncoding::Utf16`.
    pub position_encoding: PositionEncoding,
    /// Optional tokenizer dictionary configuration; `None` by default.
    pub dictionary: Option<DictionaryConfig>,
}
175
/// A borrowed piece of text tagged with a script and an offset.
///
/// NOTE(review): `offset` is presumably where `text` starts within the string
/// it was sliced from; its unit (bytes vs. chars) is not visible here — confirm
/// with the tokenizer.
#[derive(Debug, Clone, Copy)]
pub struct Segment<'a> {
    /// Script classification of this segment.
    pub script: SegmentScript,
    /// The segment's text, borrowed for `'a`.
    pub text: &'a str,
    /// Offset associated with `text` (see the note on the struct).
    pub offset: usize,
}
182
/// Owned token built from a `TokenWithScript` (see the `From` impl).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenDraft {
    /// Token text; taken from `TokenWithScript::term`.
    pub text: String,
    /// `(start, end)` pair; taken from `TokenWithScript::{start, end}`.
    pub span: (usize, usize),
    /// Script classification carried over from the tokenizer.
    pub script: SegmentScript,
    /// Offset map carried over from `TokenWithScript::offset_map`.
    pub mapping: OffsetMap,
}
190
191impl From<TokenWithScript> for TokenDraft {
192 fn from(value: TokenWithScript) -> Self {
193 Self {
194 text: value.term,
195 span: (value.start, value.end),
196 script: value.script,
197 mapping: value.offset_map,
198 }
199 }
200}
201
/// A term with its span, script and offset map.
///
/// Has the same shape as [`TokenDraft`], with `term` in place of `text`.
/// NOTE(review): presumably the result of normalizing a `TokenDraft`; the
/// normalization step is not in this file — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NormalizedTerm {
    /// The term text (presumably after normalization).
    pub term: String,
    /// `(start, end)` pair associated with the term.
    pub span: (usize, usize),
    /// Script classification of the term.
    pub script: SegmentScript,
    /// Offset map associated with the term.
    pub mapping: OffsetMap,
}
209
/// A term tagged with the [`TermDomain`] it belongs to.
///
/// NOTE(review): field meanings are inferred from names; the code producing
/// these tokens is not in this file — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineToken {
    /// The term text for `domain`.
    pub term: String,
    /// `(start, end)` pair associated with the term.
    pub span: (usize, usize),
    /// Domain this token is indexed under.
    pub domain: TermDomain,
    /// Presumably the original term that `term` was derived from.
    pub base_term: String,
}
217
218pub struct TokenStream {
219 pub tokens: Vec<PipelineToken>,
220 pub term_freqs: HashMap<String, TermFrequency>,
221 pub doc_len: i64,
222}
223
/// Serializable snapshot of index data.
///
/// The fields mirror the per-key values of [`InMemoryIndex`] (`docs`, `domains`,
/// `total_lens`, `domain_total_lens`), so this presumably captures one index —
/// confirm with the save/load code.
#[derive(Debug, Serialize, Deserialize)]
pub struct SnapshotData {
    /// Snapshot format version; deserializes to `0` when absent from the input.
    #[serde(default)]
    pub version: u32,
    /// Stored documents by key (presumably document id).
    pub docs: HashMap<String, DocData>,
    /// Index data per domain.
    pub domains: HashMap<TermDomain, DomainIndex>,
    /// Total length value for the snapshot.
    pub total_len: i64,
    /// Per-domain total lengths for the snapshot.
    pub domain_total_len: DomainLengths,
}
235
/// A term paired with the [`TermDomain`] it was matched in.
///
/// Appears in [`SearchHit::matched_terms`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MatchedTerm {
    /// The term text.
    pub term: String,
    /// The domain associated with the term.
    pub domain: TermDomain,
}
242
243impl MatchedTerm {
244 pub fn new(term: String, domain: TermDomain) -> Self {
245 Self { term, domain }
246 }
247}
248
/// Selects the position encoding used by an [`InMemoryIndex`].
///
/// NOTE(review): presumably the unit in which text positions are expressed;
/// the code that applies it is not in this file — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionEncoding {
    /// Presumably positions counted in UTF-8 bytes.
    Bytes,
    /// Presumably positions counted in UTF-16 code units. This is the value
    /// chosen by `InMemoryIndex::default()`.
    Utf16,
}
257
258impl Default for InMemoryIndex {
259 fn default() -> Self {
260 Self {
261 versions: HashMap::new(),
262 docs: HashMap::new(),
263 domains: HashMap::new(),
264 total_lens: HashMap::new(),
265 domain_total_lens: HashMap::new(),
266 dirty: HashMap::new(),
267 deleted: HashMap::new(),
268 position_encoding: PositionEncoding::Utf16,
269 dictionary: None,
270 }
271 }
272}
273
/// One search result.
///
/// NOTE(review): how `score` is computed and ordered is not visible in this
/// file — confirm with the search code.
#[derive(Debug, Clone)]
pub struct SearchHit {
    /// Identifier of the matching document.
    pub doc_id: String,
    /// Score assigned to this hit (presumably higher is better).
    pub score: f64,
    /// The terms, with their domains, associated with this hit.
    pub matched_terms: Vec<MatchedTerm>,
}
281
/// One `i64` length per [`TermDomain`], stored in a fixed-size array.
///
/// The array is private; use the methods on this type, which translate a
/// domain to its slot via `domain_index`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub struct DomainLengths {
    // Slot `domain_index(d)` holds the length for domain `d`.
    lens: [i64; TERM_DOMAIN_COUNT],
}
286
287impl Default for DomainLengths {
288 fn default() -> Self {
289 Self {
290 lens: [0; TERM_DOMAIN_COUNT],
291 }
292 }
293}
294
295impl DomainLengths {
296 pub fn get(&self, domain: TermDomain) -> i64 {
297 self.lens[domain_index(domain)]
298 }
299
300 pub fn clear(&mut self) {
301 self.lens = [0; TERM_DOMAIN_COUNT];
302 }
303
304 pub fn add(&mut self, domain: TermDomain, delta: i64) {
305 let idx = domain_index(domain);
306 self.lens[idx] += delta;
307 }
308
309 pub fn is_zero(&self) -> bool {
310 self.lens.iter().all(|&v| v == 0)
311 }
312
313 pub fn for_each_nonzero(&self, mut f: impl FnMut(TermDomain, i64)) {
314 for domain in all_domains() {
315 let len = self.get(*domain);
316 if len != 0 {
317 f(*domain, len);
318 }
319 }
320 }
321
322 pub fn from_term_freqs(freqs: &HashMap<String, TermFrequency>) -> Self {
323 let mut lengths = Self::default();
324 for freqs in freqs.values() {
325 for (domain, count) in freqs.positive_domains() {
326 lengths.add(domain, count as i64);
327 }
328 }
329 lengths
330 }
331
332 pub fn from_doc(doc: &DocData) -> Self {
333 if !doc.domain_doc_len.is_zero() {
334 return doc.domain_doc_len;
335 }
336 let mut lengths = Self::from_term_freqs(&doc.term_freqs);
337 if lengths.is_zero() {
338 lengths.add(TermDomain::Original, doc.doc_len);
339 }
340 lengths
341 }
342}