memory_indexer/
types.rs

1use std::collections::{HashMap, HashSet};
2
3use serde::{Deserialize, Serialize};
4
5use super::tokenizer::{DictionaryConfig, OffsetMap, SegmentScript, TokenWithScript};
6
/// Format version of persisted index snapshots.
///
/// NOTE(review): presumably written into / compared against `SnapshotData::version` by the
/// persistence code — confirm in the loader.
pub const SNAPSHOT_VERSION: u32 = 4;
8
/// Strategy used when executing a search query.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SearchMode {
    /// Match against original terms only.
    Exact,
    /// Match against the derived pinyin domains.
    Pinyin,
    /// Permit fuzzy matching for more tolerant recall.
    Fuzzy,
    /// Start with exact matching, falling back to pinyin and then fuzzy matching.
    Auto,
}
21
22/// Token domain representing how a term was derived or transformed.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
24pub enum TermDomain {
25    Original,
26    PinyinFull,
27    PinyinInitials,
28    PinyinFullPrefix,
29    PinyinInitialsPrefix,
30}
31
32const TERM_DOMAIN_COUNT: usize = 5;
33
34const fn domain_index(domain: TermDomain) -> usize {
35    match domain {
36        TermDomain::Original => 0,
37        TermDomain::PinyinFull => 1,
38        TermDomain::PinyinInitials => 2,
39        TermDomain::PinyinFullPrefix => 3,
40        TermDomain::PinyinInitialsPrefix => 4,
41    }
42}
43
44impl TermDomain {
45    /// Returns true if the domain represents a pinyin-derived token.
46    pub fn is_pinyin(&self) -> bool {
47        matches!(
48            self,
49            TermDomain::PinyinFull
50                | TermDomain::PinyinInitials
51                | TermDomain::PinyinFullPrefix
52                | TermDomain::PinyinInitialsPrefix
53        )
54    }
55
56    /// Returns true if the domain stores prefix tokens rather than full terms.
57    pub fn is_prefix(&self) -> bool {
58        matches!(
59            self,
60            TermDomain::PinyinFullPrefix | TermDomain::PinyinInitialsPrefix
61        )
62    }
63}
64
65#[derive(Debug, Clone, Copy)]
66pub struct DomainConfig {
67    pub weight: f64,
68    pub enable_ngrams: bool,
69    pub allow_fuzzy: bool,
70}
71
72pub fn domain_config(domain: TermDomain) -> DomainConfig {
73    match domain {
74        TermDomain::Original => DomainConfig {
75            weight: 1.0,
76            enable_ngrams: true,
77            allow_fuzzy: true,
78        },
79        TermDomain::PinyinFull => DomainConfig {
80            weight: 0.9,
81            enable_ngrams: true,
82            allow_fuzzy: true,
83        },
84        TermDomain::PinyinInitials => DomainConfig {
85            weight: 0.8,
86            enable_ngrams: true,
87            allow_fuzzy: true,
88        },
89        TermDomain::PinyinFullPrefix => DomainConfig {
90            weight: 0.7,
91            enable_ngrams: false,
92            allow_fuzzy: false,
93        },
94        TermDomain::PinyinInitialsPrefix => DomainConfig {
95            weight: 0.75,
96            enable_ngrams: false,
97            allow_fuzzy: false,
98        },
99    }
100}
101
102pub fn all_domains() -> &'static [TermDomain] {
103    &[
104        TermDomain::Original,
105        TermDomain::PinyinFull,
106        TermDomain::PinyinInitials,
107        TermDomain::PinyinFullPrefix,
108        TermDomain::PinyinInitialsPrefix,
109    ]
110}
111
112#[derive(Debug, Clone, Default, Serialize, Deserialize)]
113pub struct DomainIndex {
114    pub postings: HashMap<String, HashMap<String, i64>>,
115    pub term_dict: HashSet<String>,
116    pub ngram_index: HashMap<String, Vec<String>>,
117}
118
119#[derive(Debug, Clone, Default, Serialize, Deserialize)]
120#[serde(default)]
121pub struct TermFrequency {
122    pub counts: HashMap<TermDomain, u32>,
123}
124
125impl TermFrequency {
126    pub fn increment(&mut self, domain: TermDomain) {
127        *self.counts.entry(domain).or_default() += 1;
128    }
129
130    pub fn get(&self, domain: TermDomain) -> u32 {
131        *self.counts.get(&domain).unwrap_or(&0)
132    }
133
134    pub fn positive_domains(&self) -> Vec<(TermDomain, u32)> {
135        let mut domains = Vec::new();
136        for domain in all_domains() {
137            if let Some(count) = self.counts.get(domain) {
138                if *count > 0 {
139                    domains.push((*domain, *count));
140                }
141            }
142        }
143        domains
144    }
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct DocData {
149    pub content: String,
150    /// Document length in normalized tokens.
151    pub doc_len: i64,
152    /// Term positions for original-domain tokens.
153    pub term_pos: HashMap<String, Vec<(u32, u32)>>,
154    #[serde(default)]
155    pub term_freqs: HashMap<String, TermFrequency>,
156    #[serde(default)]
157    pub domain_doc_len: DomainLengths,
158    #[serde(default)]
159    pub derived_terms: HashMap<String, Vec<(u32, u32)>>,
160}
161
162/// In-memory inverted index supporting exact, pinyin, and fuzzy search over documents.
163#[derive(Debug)]
164pub struct InMemoryIndex {
165    pub versions: HashMap<String, u32>,
166    pub docs: HashMap<String, HashMap<String, DocData>>,
167    pub domains: HashMap<String, HashMap<TermDomain, DomainIndex>>,
168    pub total_lens: HashMap<String, i64>,
169    pub domain_total_lens: HashMap<String, DomainLengths>,
170    pub dirty: HashMap<String, HashSet<String>>,
171    pub deleted: HashMap<String, HashSet<String>>,
172    pub position_encoding: PositionEncoding,
173    pub dictionary: Option<DictionaryConfig>,
174}
175
176#[derive(Debug, Clone, Copy)]
177pub struct Segment<'a> {
178    pub script: SegmentScript,
179    pub text: &'a str,
180    pub offset: usize,
181}
182
183#[derive(Debug, Clone, PartialEq, Eq)]
184pub struct TokenDraft {
185    pub text: String,
186    pub span: (usize, usize),
187    pub script: SegmentScript,
188    pub mapping: OffsetMap,
189}
190
191impl From<TokenWithScript> for TokenDraft {
192    fn from(value: TokenWithScript) -> Self {
193        Self {
194            text: value.term,
195            span: (value.start, value.end),
196            script: value.script,
197            mapping: value.offset_map,
198        }
199    }
200}
201
202#[derive(Debug, Clone, PartialEq, Eq)]
203pub struct NormalizedTerm {
204    pub term: String,
205    pub span: (usize, usize),
206    pub script: SegmentScript,
207    pub mapping: OffsetMap,
208}
209
210#[derive(Debug, Clone, PartialEq, Eq)]
211pub struct PipelineToken {
212    pub term: String,
213    pub span: (usize, usize),
214    pub domain: TermDomain,
215    pub base_term: String,
216}
217
218pub struct TokenStream {
219    pub tokens: Vec<PipelineToken>,
220    pub term_freqs: HashMap<String, TermFrequency>,
221    pub doc_len: i64,
222}
223
224/// Snapshot of per-domain auxiliary structures.
225/// Persisted index state including documents and aux domain data.
226#[derive(Debug, Serialize, Deserialize)]
227pub struct SnapshotData {
228    #[serde(default)]
229    pub version: u32,
230    pub docs: HashMap<String, DocData>,
231    pub domains: HashMap<TermDomain, DomainIndex>,
232    pub total_len: i64,
233    pub domain_total_len: DomainLengths,
234}
235
236/// Term and domain that matched during search.
237#[derive(Debug, Clone, PartialEq, Eq, Hash)]
238pub struct MatchedTerm {
239    pub term: String,
240    pub domain: TermDomain,
241}
242
243impl MatchedTerm {
244    pub fn new(term: String, domain: TermDomain) -> Self {
245        Self { term, domain }
246    }
247}
248
/// Unit in which returned match spans are expressed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PositionEncoding {
    /// Offsets counted in raw bytes.
    Bytes,
    /// Offsets counted in UTF-16 code units (handy for JS/DOM string indexing).
    Utf16,
}
257
258impl Default for InMemoryIndex {
259    fn default() -> Self {
260        Self {
261            versions: HashMap::new(),
262            docs: HashMap::new(),
263            domains: HashMap::new(),
264            total_lens: HashMap::new(),
265            domain_total_lens: HashMap::new(),
266            dirty: HashMap::new(),
267            deleted: HashMap::new(),
268            position_encoding: PositionEncoding::Utf16,
269            dictionary: None,
270        }
271    }
272}
273
274/// Search hit containing the doc id, score, and matched terms/domains.
275#[derive(Debug, Clone)]
276pub struct SearchHit {
277    pub doc_id: String,
278    pub score: f64,
279    pub matched_terms: Vec<MatchedTerm>,
280}
281
282#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
283pub struct DomainLengths {
284    lens: [i64; TERM_DOMAIN_COUNT],
285}
286
287impl Default for DomainLengths {
288    fn default() -> Self {
289        Self {
290            lens: [0; TERM_DOMAIN_COUNT],
291        }
292    }
293}
294
295impl DomainLengths {
296    pub fn get(&self, domain: TermDomain) -> i64 {
297        self.lens[domain_index(domain)]
298    }
299
300    pub fn clear(&mut self) {
301        self.lens = [0; TERM_DOMAIN_COUNT];
302    }
303
304    pub fn add(&mut self, domain: TermDomain, delta: i64) {
305        let idx = domain_index(domain);
306        self.lens[idx] += delta;
307    }
308
309    pub fn is_zero(&self) -> bool {
310        self.lens.iter().all(|&v| v == 0)
311    }
312
313    pub fn for_each_nonzero(&self, mut f: impl FnMut(TermDomain, i64)) {
314        for domain in all_domains() {
315            let len = self.get(*domain);
316            if len != 0 {
317                f(*domain, len);
318            }
319        }
320    }
321
322    pub fn from_term_freqs(freqs: &HashMap<String, TermFrequency>) -> Self {
323        let mut lengths = Self::default();
324        for freqs in freqs.values() {
325            for (domain, count) in freqs.positive_domains() {
326                lengths.add(domain, count as i64);
327            }
328        }
329        lengths
330    }
331
332    pub fn from_doc(doc: &DocData) -> Self {
333        if !doc.domain_doc_len.is_zero() {
334            return doc.domain_doc_len;
335        }
336        let mut lengths = Self::from_term_freqs(&doc.term_freqs);
337        if lengths.is_zero() {
338            lengths.add(TermDomain::Original, doc.doc_len);
339        }
340        lengths
341    }
342}