1use std::collections::{HashMap, HashSet};
2use std::sync::RwLock;
3
4use serde::{Deserialize, Serialize};
5use smol_str::SmolStr;
6
7use super::tokenizer::{DictionaryConfig, OffsetMap, SegmentScript, TokenWithScript};
8
/// Version number reported by [`InMemoryIndex::snapshot_version`].
pub const SNAPSHOT_VERSION: u32 = 5;

/// Numeric identifier for a term.
// NOTE(review): presumably an index into `IndexState::terms` — confirm against the indexing code.
pub type TermId = u32;
/// Numeric identifier for a document.
// NOTE(review): presumably an index into `IndexState::docs` — confirm against the indexing code.
pub type DocId = u32;
13
/// Selects how a query should be matched.
///
/// NOTE(review): the handling of each mode lives outside this section; the
/// per-variant notes below are inferred from the variant names only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
    /// Presumably exact term matching — confirm against the search code.
    Exact,
    /// Presumably matching through the pinyin domains — confirm against the search code.
    Pinyin,
    /// Presumably approximate (fuzzy) matching — confirm against the search code.
    Fuzzy,
    /// Presumably lets the engine pick a strategy — confirm against the search code.
    Auto,
}
26
/// The term domains tracked by the index.
///
/// Each variant owns a fixed slot (see [`domain_index`]) in per-domain arrays
/// such as [`TermFrequencyEntry::counts`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TermDomain {
    /// The non-pinyin domain (presumably terms as they appear in the source text).
    Original,
    /// A pinyin domain (presumably full pinyin spellings — TODO confirm).
    PinyinFull,
    /// A pinyin domain (presumably the initial letters of pinyin syllables — TODO confirm).
    PinyinInitials,
}
34
/// Length of every per-domain array (`TermFrequencyEntry::counts`, `DomainLengths`).
///
/// Must stay equal to the number of [`TermDomain`] variants handled by [`domain_index`].
pub(crate) const TERM_DOMAIN_COUNT: usize = 3;
36
/// Maps a [`TermDomain`] to its slot in fixed-size per-domain arrays.
///
/// The result is always smaller than [`TERM_DOMAIN_COUNT`].
pub(crate) const fn domain_index(domain: TermDomain) -> usize {
    // Spelled out per variant so the slot layout does not silently follow
    // the enum's declaration order.
    match domain {
        TermDomain::Original => 0,
        TermDomain::PinyinFull => 1,
        TermDomain::PinyinInitials => 2,
    }
}
44
45impl TermDomain {
46 pub fn is_pinyin(&self) -> bool {
48 matches!(self, TermDomain::PinyinFull | TermDomain::PinyinInitials)
49 }
50}
51
/// Per-domain settings, as produced by [`domain_config`].
#[derive(Debug, Clone, Copy)]
pub struct DomainConfig {
    /// Weight assigned to the domain.
    // NOTE(review): presumably a score multiplier — confirm against the scoring code.
    pub weight: f64,
    /// Whether fuzzy matching is permitted for the domain.
    pub allow_fuzzy: bool,
}
57
58pub fn domain_config(domain: TermDomain) -> DomainConfig {
59 match domain {
60 TermDomain::Original => DomainConfig {
61 weight: 1.0,
62 allow_fuzzy: true,
63 },
64 TermDomain::PinyinFull => DomainConfig {
65 weight: 0.9,
66 allow_fuzzy: true,
67 },
68 TermDomain::PinyinInitials => DomainConfig {
69 weight: 0.8,
70 allow_fuzzy: true,
71 },
72 }
73}
74
75pub fn all_domains() -> &'static [TermDomain] {
76 &[
77 TermDomain::Original,
78 TermDomain::PinyinFull,
79 TermDomain::PinyinInitials,
80 ]
81}
82
/// One entry of a posting list (see `DomainIndex::postings`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Posting {
    /// Document this posting refers to.
    pub doc: DocId,
    /// Frequency value recorded for that document.
    // NOTE(review): presumably the number of occurrences of the term in `doc` — confirm.
    pub freq: u32,
}
88
/// Auxiliary per-domain lookup structures held behind the lock in `DomainIndex::aux`.
///
/// Every field is optional and starts out as `None` (derived `Default`).
// NOTE(review): these look like lazily built caches — confirm where they are populated.
#[derive(Debug, Default, Clone)]
pub struct DomainAux {
    /// Optional list of term ids.
    pub term_ids: Option<Vec<TermId>>,
    /// Optional map from a string key to term ids (presumably n-grams — TODO confirm).
    pub ngram_index: Option<HashMap<SmolStr, Vec<TermId>>>,
    /// Optional map from a string key to term ids (presumably prefixes — TODO confirm).
    pub prefix_index: Option<HashMap<SmolStr, Vec<TermId>>>,
}
95
96impl DomainAux {
97 pub fn clear(&mut self) {
98 *self = Self::default();
99 }
100
101 fn default_lock() -> RwLock<Self> {
102 RwLock::new(Self::default())
103 }
104}
105
/// Posting lists for one term domain, plus non-persistent auxiliary data.
#[derive(Debug, Serialize, Deserialize)]
pub struct DomainIndex {
    /// Posting list for each term id.
    pub postings: HashMap<TermId, Vec<Posting>>,
    /// Auxiliary structures. They are never serialized; deserialization
    /// recreates the field through `DomainAux::default_lock`.
    #[serde(skip, default = "DomainAux::default_lock")]
    pub aux: RwLock<DomainAux>,
}
112
113impl Clone for DomainIndex {
114 fn clone(&self) -> Self {
115 Self {
116 postings: self.postings.clone(),
117 aux: RwLock::new(DomainAux::default()),
118 }
119 }
120}
121
122impl Default for DomainIndex {
123 fn default() -> Self {
124 Self {
125 postings: HashMap::new(),
126 aux: RwLock::new(DomainAux::default()),
127 }
128 }
129}
130
/// Positions recorded for a single term (stored in `DocData::term_pos`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TermPositions {
    /// The term these positions belong to.
    pub term: TermId,
    /// Recorded position pairs.
    // NOTE(review): presumably (start, end) offsets in the index's position encoding — confirm.
    pub positions: Vec<(u32, u32)>,
}
136
/// Per-domain counts for a single term.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct TermFrequencyEntry {
    /// The term being counted.
    pub term: TermId,
    /// One count per domain, indexed by `domain_index`.
    pub counts: [u32; TERM_DOMAIN_COUNT],
}
142
143impl TermFrequencyEntry {
144 pub fn positive_domains(&self) -> Vec<(TermDomain, u32)> {
145 let mut domains = Vec::new();
146 for domain in all_domains() {
147 let count = self.counts[domain_index(*domain)];
148 if count > 0 {
149 domains.push((*domain, count));
150 }
151 }
152 domains
153 }
154}
155
/// Links a derived term to the term it was derived from.
// NOTE(review): presumably pinyin terms derived from original-domain terms — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct DerivedTerm {
    /// Id of the derived term.
    pub derived: TermId,
    /// Id of the base term.
    pub base: TermId,
}
161
/// Associates a derived term with a span.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct DerivedSpan {
    /// Id of the derived term.
    pub derived: TermId,
    /// Span tied to the derived term.
    // NOTE(review): presumably (start, end) in the document content — confirm units.
    pub span: (u32, u32),
}
167
/// Everything stored for one indexed document.
///
/// Fields marked `#[serde(default)]` fall back to their `Default` value when
/// absent from the serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocData {
    /// Document text.
    pub content: String,
    /// Overall document length; `DomainLengths::from_doc` uses it as the
    /// `Original`-domain length when no per-domain data is available.
    pub doc_len: i64,
    /// Positions recorded per term.
    pub term_pos: Vec<TermPositions>,
    /// Per-domain frequency counts per term.
    pub term_freqs: Vec<TermFrequencyEntry>,
    /// Per-domain document lengths; all zeros when absent from the input.
    #[serde(default)]
    pub domain_doc_len: DomainLengths,
    /// Derived-term to base-term links; empty when absent from the input.
    #[serde(default)]
    pub derived_terms: Vec<DerivedTerm>,
    /// Spans of derived terms; empty when absent from the input.
    #[serde(default)]
    pub derived_spans: Vec<DerivedSpan>,
}
186
/// Top-level container holding any number of named index states.
#[derive(Debug)]
pub struct InMemoryIndex {
    /// Index states keyed by a string.
    // NOTE(review): presumably the index name — confirm against callers.
    pub indexes: HashMap<String, IndexState>,
    /// Position encoding setting; `Default` selects `PositionEncoding::Utf16`.
    pub position_encoding: PositionEncoding,
    /// Optional dictionary configuration; `None` by default.
    pub dictionary: Option<DictionaryConfig>,
}
194
/// In-memory state of a single index.
///
/// NOTE(review): the field notes below are inferred from names and types only;
/// the code that maintains these invariants is outside this section.
#[derive(Debug, Default)]
pub struct IndexState {
    /// Version number stored with the state.
    pub version: u32,
    /// Term strings (presumably indexed by `TermId` — TODO confirm).
    pub terms: Vec<SmolStr>,
    /// Lookup from term string to its `TermId`.
    pub term_index: HashMap<SmolStr, TermId>,
    /// Document slots; a slot may be `None` (presumably a removed document — TODO confirm).
    pub docs: Vec<Option<DocData>>,
    /// External document identifiers (presumably indexed by `DocId` — TODO confirm).
    pub doc_ids: Vec<SmolStr>,
    /// Lookup from external document identifier to its `DocId`.
    pub doc_index: HashMap<SmolStr, DocId>,
    /// Document ids listed as free (presumably slots available for reuse — TODO confirm).
    pub free_docs: Vec<DocId>,
    /// Posting data per term domain.
    pub domains: HashMap<TermDomain, DomainIndex>,
    /// Aggregate length value (presumably the sum of `doc_len` over all documents — TODO confirm).
    pub total_len: i64,
    /// Aggregate per-domain lengths.
    pub domain_total_len: DomainLengths,
    /// Identifiers flagged as dirty (presumably pending persistence — TODO confirm).
    pub dirty: HashSet<SmolStr>,
    /// Identifiers flagged as deleted (presumably pending persistence — TODO confirm).
    pub deleted: HashSet<SmolStr>,
}
210
impl InMemoryIndex {
    /// Returns [`SNAPSHOT_VERSION`].
    pub fn snapshot_version() -> u32 {
        SNAPSHOT_VERSION
    }
}
216
/// A borrowed slice of text tagged with a script.
#[derive(Debug, Clone, Copy)]
pub struct Segment<'a> {
    /// Script assigned to the segment.
    pub script: SegmentScript,
    /// The segment's text, borrowed for `'a`.
    pub text: &'a str,
    /// Offset of the segment.
    // NOTE(review): presumably a byte offset into the source text — confirm units.
    pub offset: usize,
}
223
/// An owned token built from a `TokenWithScript` (see the `From` impl).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenDraft {
    /// Token text (the source token's `term`).
    pub text: String,
    /// `(start, end)` taken from the source token.
    pub span: (usize, usize),
    /// Script of the token.
    pub script: SegmentScript,
    /// Offset map carried over from the source token.
    pub mapping: OffsetMap,
}
231
232impl From<TokenWithScript> for TokenDraft {
233 fn from(value: TokenWithScript) -> Self {
234 Self {
235 text: value.term,
236 span: (value.start, value.end),
237 script: value.script,
238 mapping: value.offset_map,
239 }
240 }
241}
242
/// A term together with its span, script, and offset map.
// NOTE(review): presumably the output of a normalization step — confirm against the pipeline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NormalizedTerm {
    /// The term text.
    pub term: String,
    /// Span of the term (presumably (start, end) — TODO confirm units).
    pub span: (usize, usize),
    /// Script of the term.
    pub script: SegmentScript,
    /// Offset map associated with the term.
    pub mapping: OffsetMap,
}
250
/// A term tagged with its domain and the base term it relates to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineToken {
    /// The term text.
    pub term: String,
    /// Span of the term (presumably (start, end) — TODO confirm units).
    pub span: (usize, usize),
    /// Domain the term belongs to.
    pub domain: TermDomain,
    /// Base term text.
    // NOTE(review): presumably the original-domain term a pinyin term was derived from — confirm.
    pub base_term: String,
}
258
/// A list of pipeline tokens plus a document length.
pub struct TokenStream {
    /// The tokens.
    pub tokens: Vec<PipelineToken>,
    /// Document length reported alongside the tokens.
    pub doc_len: i64,
}
263
/// Serializable form of an index.
///
/// Carries the `version`, `terms`, `docs`, `doc_ids`, `domains`, `total_len`
/// and `domain_total_len` fields that also exist on `IndexState`; the lookup
/// maps and bookkeeping sets of `IndexState` are not part of it.
#[derive(Debug, Serialize, Deserialize)]
pub struct SnapshotData {
    /// Version number; falls back to `0` when absent from the serialized input.
    #[serde(default)]
    pub version: u32,
    /// Term strings.
    pub terms: Vec<SmolStr>,
    /// Document slots.
    pub docs: Vec<Option<DocData>>,
    /// External document identifiers.
    pub doc_ids: Vec<SmolStr>,
    /// Posting data per term domain.
    pub domains: HashMap<TermDomain, DomainIndex>,
    /// Aggregate length value.
    pub total_len: i64,
    /// Aggregate per-domain lengths.
    pub domain_total_len: DomainLengths,
}
276
/// A term paired with the domain it was matched in (see `SearchHit::matched_terms`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MatchedTerm {
    /// The term text.
    pub term: String,
    /// Domain of the term.
    pub domain: TermDomain,
}
283
284impl MatchedTerm {
285 pub fn new(term: String, domain: TermDomain) -> Self {
286 Self { term, domain }
287 }
288}
289
/// Position encoding setting stored on `InMemoryIndex`.
// NOTE(review): presumably the unit in which positions/spans are reported — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionEncoding {
    /// Presumably byte offsets — TODO confirm.
    Bytes,
    /// Presumably UTF-16 code-unit offsets — TODO confirm. This is the value
    /// chosen by `InMemoryIndex::default`.
    Utf16,
}
298
299impl Default for InMemoryIndex {
300 fn default() -> Self {
301 Self {
302 indexes: HashMap::new(),
303 position_encoding: PositionEncoding::Utf16,
304 dictionary: None,
305 }
306 }
307}
308
/// One search result.
#[derive(Debug, Clone)]
pub struct SearchHit {
    /// Identifier of the matching document.
    pub doc_id: String,
    /// Score of the hit.
    // NOTE(review): presumably higher means more relevant — confirm against the ranking code.
    pub score: f64,
    /// Terms (with their domains) associated with the hit.
    pub matched_terms: Vec<MatchedTerm>,
}
316
/// One signed length value per term domain.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub struct DomainLengths {
    /// Lengths indexed by `domain_index`.
    lens: [i64; TERM_DOMAIN_COUNT],
}
321
322impl Default for DomainLengths {
323 fn default() -> Self {
324 Self {
325 lens: [0; TERM_DOMAIN_COUNT],
326 }
327 }
328}
329
330impl DomainLengths {
331 pub fn get(&self, domain: TermDomain) -> i64 {
332 self.lens[domain_index(domain)]
333 }
334
335 pub fn clear(&mut self) {
336 self.lens = [0; TERM_DOMAIN_COUNT];
337 }
338
339 pub fn add(&mut self, domain: TermDomain, delta: i64) {
340 let idx = domain_index(domain);
341 self.lens[idx] += delta;
342 }
343
344 pub fn is_zero(&self) -> bool {
345 self.lens.iter().all(|&v| v == 0)
346 }
347
348 pub fn for_each_nonzero(&self, mut f: impl FnMut(TermDomain, i64)) {
349 for domain in all_domains() {
350 let len = self.get(*domain);
351 if len != 0 {
352 f(*domain, len);
353 }
354 }
355 }
356
357 pub fn from_term_freqs(freqs: &[TermFrequencyEntry]) -> Self {
358 let mut lengths = Self::default();
359 for entry in freqs {
360 for (domain, count) in entry.positive_domains() {
361 lengths.add(domain, count as i64);
362 }
363 }
364 lengths
365 }
366
367 pub fn from_doc(doc: &DocData) -> Self {
368 if !doc.domain_doc_len.is_zero() {
369 return doc.domain_doc_len;
370 }
371 let mut lengths = Self::from_term_freqs(&doc.term_freqs);
372 if lengths.is_zero() {
373 lengths.add(TermDomain::Original, doc.doc_len);
374 }
375 lengths
376 }
377}