1use std::collections::{HashMap, HashSet};
2
3use serde::{Deserialize, Serialize};
4
5use super::tokenizer::{DictionaryConfig, OffsetMap, SegmentScript, TokenWithScript};
6
/// Snapshot format version number (currently `3`).
///
/// NOTE(review): presumably written to / compared against [`SnapshotData::version`],
/// which deserializes to `0` when the field is absent; the comparison logic is not in
/// this section — confirm.
pub const SNAPSHOT_VERSION: u32 = 3;
8
/// Matching strategy a search can request.
///
/// NOTE(review): the behavior of each variant is implemented outside this section;
/// the per-variant descriptions below are inferred from the names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
    /// Match against the original text (presumably).
    Exact,
    /// Match through pinyin-derived terms (presumably).
    Pinyin,
    /// Allow approximate matches (presumably).
    Fuzzy,
    /// Let the search code choose a strategy (presumably).
    Auto,
}
21
/// The namespace an indexed term belongs to.
///
/// `Original` is the only non-pinyin domain (see [`TermDomain::is_pinyin`]). The two
/// `*Prefix` variants are the prefix domains (see [`TermDomain::is_prefix`]). Each
/// domain has its own options in [`domain_config`] and a fixed slot in
/// [`domain_index`].
///
/// The enum derives `Serialize`/`Deserialize` and is used as a map key in
/// [`TermFrequency`] and [`SnapshotData`]. Serde's default representation uses the
/// variant names, so renaming a variant changes the serialized format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TermDomain {
    /// Terms that are not pinyin-derived.
    Original,
    /// Pinyin-derived terms, full spelling (presumably — inferred from the name).
    PinyinFull,
    /// Pinyin-derived terms, initials only (presumably — inferred from the name).
    PinyinInitials,
    /// Prefix form of the full-pinyin domain.
    PinyinFullPrefix,
    /// Prefix form of the pinyin-initials domain.
    PinyinInitialsPrefix,
}
31
/// Number of [`TermDomain`] variants.
///
/// Sizes the per-domain array inside [`DomainLengths`]. It must equal the number of
/// entries returned by [`all_domains`] and be greater than every index produced by
/// [`domain_index`].
const TERM_DOMAIN_COUNT: usize = 5;
33
34const fn domain_index(domain: TermDomain) -> usize {
35 match domain {
36 TermDomain::Original => 0,
37 TermDomain::PinyinFull => 1,
38 TermDomain::PinyinInitials => 2,
39 TermDomain::PinyinFullPrefix => 3,
40 TermDomain::PinyinInitialsPrefix => 4,
41 }
42}
43
44impl TermDomain {
45 pub fn is_pinyin(&self) -> bool {
47 matches!(
48 self,
49 TermDomain::PinyinFull
50 | TermDomain::PinyinInitials
51 | TermDomain::PinyinFullPrefix
52 | TermDomain::PinyinInitialsPrefix
53 )
54 }
55
56 pub fn is_prefix(&self) -> bool {
58 matches!(
59 self,
60 TermDomain::PinyinFullPrefix | TermDomain::PinyinInitialsPrefix
61 )
62 }
63}
64
/// Per-domain options, as returned by [`domain_config`].
#[derive(Debug, Clone, Copy)]
pub struct DomainConfig {
    /// Relative weight of the domain: `1.0` for `Original`, lower for the pinyin
    /// domains. NOTE(review): presumably a score multiplier — confirm where it is
    /// applied.
    pub weight: f64,
    /// Whether n-gram lookups are enabled for the domain (presumably). Set to `false`
    /// for the prefix domains.
    pub enable_ngrams: bool,
    /// Whether fuzzy matching is allowed for the domain (presumably). Set to `false`
    /// for the prefix domains.
    pub allow_fuzzy: bool,
}
71
72pub fn domain_config(domain: TermDomain) -> DomainConfig {
73 match domain {
74 TermDomain::Original => DomainConfig {
75 weight: 1.0,
76 enable_ngrams: true,
77 allow_fuzzy: true,
78 },
79 TermDomain::PinyinFull => DomainConfig {
80 weight: 0.9,
81 enable_ngrams: true,
82 allow_fuzzy: true,
83 },
84 TermDomain::PinyinInitials => DomainConfig {
85 weight: 0.8,
86 enable_ngrams: true,
87 allow_fuzzy: true,
88 },
89 TermDomain::PinyinFullPrefix => DomainConfig {
90 weight: 0.7,
91 enable_ngrams: false,
92 allow_fuzzy: false,
93 },
94 TermDomain::PinyinInitialsPrefix => DomainConfig {
95 weight: 0.75,
96 enable_ngrams: false,
97 allow_fuzzy: false,
98 },
99 }
100}
101
102pub fn all_domains() -> &'static [TermDomain] {
103 &[
104 TermDomain::Original,
105 TermDomain::PinyinFull,
106 TermDomain::PinyinInitials,
107 TermDomain::PinyinFullPrefix,
108 TermDomain::PinyinInitialsPrefix,
109 ]
110}
111
/// Search structures for one [`TermDomain`].
///
/// Stored per outer index key and per domain in `InMemoryIndex::domains`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DomainIndex {
    /// Two-level map: term -> (inner key -> count).
    /// NOTE(review): presumably term -> doc_id -> term frequency — confirm against the
    /// code that fills it.
    pub postings: HashMap<String, HashMap<String, i64>>,
    /// Set of distinct terms known in this domain (presumably).
    pub term_dict: HashSet<String>,
    /// N-gram key -> list of terms (presumably the terms containing that n-gram —
    /// confirm).
    pub ngram_index: HashMap<String, Vec<String>>,
}
118
/// Per-domain occurrence counts, stored per term in `DocData::term_freqs`.
///
/// Domains missing from `counts` are treated as zero by `TermFrequency::get`.
/// `#[serde(default)]` lets a missing `counts` field deserialize as an empty map.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct TermFrequency {
    /// Domain -> number of occurrences recorded in that domain.
    pub counts: HashMap<TermDomain, u32>,
}
124
125impl TermFrequency {
126 pub fn increment(&mut self, domain: TermDomain) {
127 *self.counts.entry(domain).or_default() += 1;
128 }
129
130 pub fn get(&self, domain: TermDomain) -> u32 {
131 *self.counts.get(&domain).unwrap_or(&0)
132 }
133
134 pub fn positive_domains(&self) -> Vec<(TermDomain, u32)> {
135 let mut domains = Vec::new();
136 for domain in all_domains() {
137 if let Some(count) = self.counts.get(domain) {
138 if *count > 0 {
139 domains.push((*domain, *count));
140 }
141 }
142 }
143 domains
144 }
145}
146
/// A stored document together with the per-term data derived from it.
///
/// The fields marked `#[serde(default)]` may be absent from serialized data. They
/// then deserialize as empty maps or all-zero lengths.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocData {
    /// The document text.
    pub content: String,
    /// Overall document length. `DomainLengths::from_doc` attributes this value to
    /// `TermDomain::Original` when no per-domain lengths can be derived.
    pub doc_len: i64,
    /// Term -> list of position pairs (presumably `(start, end)`, in the unit given by
    /// `InMemoryIndex::position_encoding` — confirm).
    pub term_pos: HashMap<String, Vec<(u32, u32)>>,
    /// Term -> per-domain occurrence counts; empty when absent.
    #[serde(default)]
    pub term_freqs: HashMap<String, TermFrequency>,
    /// Per-domain document lengths. All zeros when absent; `DomainLengths::from_doc`
    /// then recomputes them from `term_freqs`.
    #[serde(default)]
    pub domain_doc_len: DomainLengths,
    /// Derived-term -> position pairs (presumably pinyin terms — inferred from the
    /// name, confirm); empty when absent.
    #[serde(default)]
    pub derived_terms: HashMap<String, Vec<(u32, u32)>>,
}
161
/// In-memory search index state.
///
/// Most fields are keyed by an outer `String`. NOTE(review): presumably an index or
/// collection name, so several independent indexes share one value — confirm.
#[derive(Debug)]
pub struct InMemoryIndex {
    /// Per-index version number (presumably bumped on mutation — confirm).
    pub versions: HashMap<String, u32>,
    /// Per-index documents, keyed by document id (presumably).
    pub docs: HashMap<String, HashMap<String, DocData>>,
    /// Per-index, per-domain search structures.
    pub domains: HashMap<String, HashMap<TermDomain, DomainIndex>>,
    /// Per-index total length.
    pub total_lens: HashMap<String, i64>,
    /// Per-index total length, broken down by domain.
    pub domain_total_lens: HashMap<String, DomainLengths>,
    /// Per-index set of changed document ids (presumably pending persistence).
    pub dirty: HashMap<String, HashSet<String>>,
    /// Per-index set of deleted document ids (presumably pending persistence).
    pub deleted: HashMap<String, HashSet<String>>,
    /// Unit for stored text positions (presumably); `Utf16` in the `Default` impl.
    pub position_encoding: PositionEncoding,
    /// Optional dictionary configuration for the tokenizer; `None` in the `Default` impl.
    pub dictionary: Option<DictionaryConfig>,
}
175
/// A borrowed run of text tagged with one script classification.
#[derive(Debug, Clone, Copy)]
pub struct Segment<'a> {
    /// Script classification of the run.
    pub script: SegmentScript,
    /// The run's text, borrowed for lifetime `'a`.
    pub text: &'a str,
    /// Position of `text` within the surrounding input (presumably a byte offset —
    /// confirm).
    pub offset: usize,
}
182
/// A token as produced by the tokenizer, built from `TokenWithScript` via the
/// `From<TokenWithScript>` impl.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenDraft {
    /// Token text (copied from `TokenWithScript::term`).
    pub text: String,
    /// `(start, end)` taken from `TokenWithScript::start` / `TokenWithScript::end`.
    pub span: (usize, usize),
    /// Script classification of the token.
    pub script: SegmentScript,
    /// Offset mapping carried over from `TokenWithScript::offset_map`.
    pub mapping: OffsetMap,
}
190
191impl From<TokenWithScript> for TokenDraft {
192 fn from(value: TokenWithScript) -> Self {
193 Self {
194 text: value.term,
195 span: (value.start, value.end),
196 script: value.script,
197 mapping: value.offset_map,
198 }
199 }
200}
201
/// A token after normalization. It has the same shape as `TokenDraft`, with the
/// text held in `term` (presumably the normalized form — confirm).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NormalizedTerm {
    /// Term text.
    pub term: String,
    /// `(start, end)` span of the token (presumably in the source text).
    pub span: (usize, usize),
    /// Script classification of the token.
    pub script: SegmentScript,
    /// Offset mapping associated with the token.
    pub mapping: OffsetMap,
}
209
/// A term emitted into a specific [`TermDomain`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineToken {
    /// The emitted term text.
    pub term: String,
    /// `(start, end)` span of the token.
    pub span: (usize, usize),
    /// Domain the term is indexed under.
    pub domain: TermDomain,
    /// The term this one was produced from (presumably the original-domain term for
    /// pinyin-derived tokens — inferred from the name, confirm).
    pub base_term: String,
}
217
218pub struct TokenStream {
219 pub tokens: Vec<PipelineToken>,
220 pub term_freqs: HashMap<String, TermFrequency>,
221 pub doc_len: i64,
222}
223
/// Per-domain data saved in a snapshot: the term dictionary and n-gram index.
///
/// Unlike `DomainIndex`, there is no `postings` field. NOTE(review): presumably
/// postings are rebuilt from the stored documents on load — confirm.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DomainSnapshot {
    /// Distinct terms of the domain (mirrors `DomainIndex::term_dict`).
    pub term_dict: HashSet<String>,
    /// N-gram index of the domain (mirrors `DomainIndex::ngram_index`).
    pub ngram_index: HashMap<String, Vec<String>>,
}
230
/// Serialized form of one index: its documents plus per-domain dictionaries.
#[derive(Debug, Serialize, Deserialize)]
pub struct SnapshotData {
    /// Format version. Deserializes to `0` when the field is absent (presumably data
    /// written before versioning — confirm against `SNAPSHOT_VERSION` handling).
    #[serde(default)]
    pub version: u32,
    /// Documents keyed by document id (presumably).
    pub docs: HashMap<String, DocData>,
    /// Per-domain dictionaries and n-gram indexes; empty when absent.
    #[serde(default)]
    pub domains: HashMap<TermDomain, DomainSnapshot>,
}
240
/// A term reported in `SearchHit::matched_terms`, together with the domain it
/// matched in.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MatchedTerm {
    /// The matched term text.
    pub term: String,
    /// The domain the match occurred in.
    pub domain: TermDomain,
}
247
248impl MatchedTerm {
249 pub fn new(term: String, domain: TermDomain) -> Self {
250 Self { term, domain }
251 }
252}
253
/// Unit used for stored text positions (presumably the offsets in
/// `DocData::term_pos` — confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionEncoding {
    /// Byte offsets (presumably into the UTF-8 text).
    Bytes,
    /// UTF-16 code-unit offsets (presumably); the `InMemoryIndex` default.
    Utf16,
}
262
263impl Default for InMemoryIndex {
264 fn default() -> Self {
265 Self {
266 versions: HashMap::new(),
267 docs: HashMap::new(),
268 domains: HashMap::new(),
269 total_lens: HashMap::new(),
270 domain_total_lens: HashMap::new(),
271 dirty: HashMap::new(),
272 deleted: HashMap::new(),
273 position_encoding: PositionEncoding::Utf16,
274 dictionary: None,
275 }
276 }
277}
278
/// One document returned by a search.
#[derive(Debug, Clone)]
pub struct SearchHit {
    /// Id of the matching document.
    pub doc_id: String,
    /// Relevance score (presumably higher is better — confirm).
    pub score: f64,
    /// Terms that matched, each with the domain it matched in.
    pub matched_terms: Vec<MatchedTerm>,
}
286
287#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
288pub struct DomainLengths {
289 lens: [i64; TERM_DOMAIN_COUNT],
290}
291
292impl Default for DomainLengths {
293 fn default() -> Self {
294 Self {
295 lens: [0; TERM_DOMAIN_COUNT],
296 }
297 }
298}
299
300impl DomainLengths {
301 pub fn get(&self, domain: TermDomain) -> i64 {
302 self.lens[domain_index(domain)]
303 }
304
305 pub fn clear(&mut self) {
306 self.lens = [0; TERM_DOMAIN_COUNT];
307 }
308
309 pub fn add(&mut self, domain: TermDomain, delta: i64) {
310 let idx = domain_index(domain);
311 self.lens[idx] += delta;
312 }
313
314 pub fn is_zero(&self) -> bool {
315 self.lens.iter().all(|&v| v == 0)
316 }
317
318 pub fn for_each_nonzero(&self, mut f: impl FnMut(TermDomain, i64)) {
319 for domain in all_domains() {
320 let len = self.get(*domain);
321 if len != 0 {
322 f(*domain, len);
323 }
324 }
325 }
326
327 pub fn from_term_freqs(freqs: &HashMap<String, TermFrequency>) -> Self {
328 let mut lengths = Self::default();
329 for freqs in freqs.values() {
330 for (domain, count) in freqs.positive_domains() {
331 lengths.add(domain, count as i64);
332 }
333 }
334 lengths
335 }
336
337 pub fn from_doc(doc: &DocData) -> Self {
338 if !doc.domain_doc_len.is_zero() {
339 return doc.domain_doc_len;
340 }
341 let mut lengths = Self::from_term_freqs(&doc.term_freqs);
342 if lengths.is_zero() {
343 lengths.add(TermDomain::Original, doc.doc_len);
344 }
345 lengths
346 }
347}