// memory_indexer/types.rs

1use std::collections::{HashMap, HashSet};
2
3use serde::{Deserialize, Serialize};
4
5use super::tokenizer::{DictionaryConfig, OffsetMap, SegmentScript, TokenWithScript};
6
/// Current on-disk snapshot format version.
///
/// NOTE(review): presumably written to / compared against `SnapshotData::version` by the
/// persistence code — confirm there.
pub const SNAPSHOT_VERSION: u32 = 3;
8
/// Strategy used when executing a search query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
    /// Match against original terms only; derived domains are not consulted.
    Exact,
    /// Match against the derived pinyin domains.
    Pinyin,
    /// Permit fuzzy matching to improve recall on imperfect queries.
    Fuzzy,
    /// Start with exact matching, falling back to pinyin and fuzzy matching.
    Auto,
}
21
22/// Token domain representing how a term was derived or transformed.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
24pub enum TermDomain {
25    Original,
26    PinyinFull,
27    PinyinInitials,
28    PinyinFullPrefix,
29    PinyinInitialsPrefix,
30}
31
/// Number of `TermDomain` variants; sizes the per-domain array inside `DomainLengths`.
/// Must stay in sync with `domain_index` and `all_domains` when variants change.
const TERM_DOMAIN_COUNT: usize = 5;
33
34const fn domain_index(domain: TermDomain) -> usize {
35    match domain {
36        TermDomain::Original => 0,
37        TermDomain::PinyinFull => 1,
38        TermDomain::PinyinInitials => 2,
39        TermDomain::PinyinFullPrefix => 3,
40        TermDomain::PinyinInitialsPrefix => 4,
41    }
42}
43
44impl TermDomain {
45    /// Returns true if the domain represents a pinyin-derived token.
46    pub fn is_pinyin(&self) -> bool {
47        matches!(
48            self,
49            TermDomain::PinyinFull
50                | TermDomain::PinyinInitials
51                | TermDomain::PinyinFullPrefix
52                | TermDomain::PinyinInitialsPrefix
53        )
54    }
55
56    /// Returns true if the domain stores prefix tokens rather than full terms.
57    pub fn is_prefix(&self) -> bool {
58        matches!(
59            self,
60            TermDomain::PinyinFullPrefix | TermDomain::PinyinInitialsPrefix
61        )
62    }
63}
64
/// Fixed per-domain search settings, as returned by `domain_config`.
#[derive(Debug, Clone, Copy)]
pub struct DomainConfig {
    /// Relative weight of the domain (presumably multiplies match scores — confirm in
    /// the scorer).
    pub weight: f64,
    /// Whether n-gram lookups are enabled for this domain.
    pub enable_ngrams: bool,
    /// Whether fuzzy matching is permitted in this domain.
    pub allow_fuzzy: bool,
}
71
72pub fn domain_config(domain: TermDomain) -> DomainConfig {
73    match domain {
74        TermDomain::Original => DomainConfig {
75            weight: 1.0,
76            enable_ngrams: true,
77            allow_fuzzy: true,
78        },
79        TermDomain::PinyinFull => DomainConfig {
80            weight: 0.9,
81            enable_ngrams: true,
82            allow_fuzzy: true,
83        },
84        TermDomain::PinyinInitials => DomainConfig {
85            weight: 0.8,
86            enable_ngrams: true,
87            allow_fuzzy: true,
88        },
89        TermDomain::PinyinFullPrefix => DomainConfig {
90            weight: 0.7,
91            enable_ngrams: false,
92            allow_fuzzy: false,
93        },
94        TermDomain::PinyinInitialsPrefix => DomainConfig {
95            weight: 0.75,
96            enable_ngrams: false,
97            allow_fuzzy: false,
98        },
99    }
100}
101
102pub fn all_domains() -> &'static [TermDomain] {
103    &[
104        TermDomain::Original,
105        TermDomain::PinyinFull,
106        TermDomain::PinyinInitials,
107        TermDomain::PinyinFullPrefix,
108        TermDomain::PinyinInitialsPrefix,
109    ]
110}
111
112#[derive(Debug, Clone, Default, Serialize, Deserialize)]
113pub struct DomainIndex {
114    pub postings: HashMap<String, HashMap<String, i64>>,
115    pub term_dict: HashSet<String>,
116    pub ngram_index: HashMap<String, Vec<String>>,
117}
118
119#[derive(Debug, Clone, Default, Serialize, Deserialize)]
120#[serde(default)]
121pub struct TermFrequency {
122    pub counts: HashMap<TermDomain, u32>,
123}
124
125impl TermFrequency {
126    pub fn increment(&mut self, domain: TermDomain) {
127        *self.counts.entry(domain).or_default() += 1;
128    }
129
130    pub fn get(&self, domain: TermDomain) -> u32 {
131        *self.counts.get(&domain).unwrap_or(&0)
132    }
133
134    pub fn positive_domains(&self) -> Vec<(TermDomain, u32)> {
135        let mut domains = Vec::new();
136        for domain in all_domains() {
137            if let Some(count) = self.counts.get(domain) {
138                if *count > 0 {
139                    domains.push((*domain, *count));
140                }
141            }
142        }
143        domains
144    }
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct DocData {
149    pub content: String,
150    /// Document length in normalized tokens.
151    pub doc_len: i64,
152    /// Term positions for original-domain tokens.
153    pub term_pos: HashMap<String, Vec<(u32, u32)>>,
154    #[serde(default)]
155    pub term_freqs: HashMap<String, TermFrequency>,
156    #[serde(default)]
157    pub domain_doc_len: DomainLengths,
158    #[serde(default)]
159    pub derived_terms: HashMap<String, Vec<(u32, u32)>>,
160}
161
162/// In-memory inverted index supporting exact, pinyin, and fuzzy search over documents.
163#[derive(Debug)]
164pub struct InMemoryIndex {
165    pub versions: HashMap<String, u32>,
166    pub docs: HashMap<String, HashMap<String, DocData>>,
167    pub domains: HashMap<String, HashMap<TermDomain, DomainIndex>>,
168    pub total_lens: HashMap<String, i64>,
169    pub domain_total_lens: HashMap<String, DomainLengths>,
170    pub dirty: HashMap<String, HashSet<String>>,
171    pub deleted: HashMap<String, HashSet<String>>,
172    pub position_encoding: PositionEncoding,
173    pub dictionary: Option<DictionaryConfig>,
174}
175
176#[derive(Debug, Clone, Copy)]
177pub struct Segment<'a> {
178    pub script: SegmentScript,
179    pub text: &'a str,
180    pub offset: usize,
181}
182
183#[derive(Debug, Clone, PartialEq, Eq)]
184pub struct TokenDraft {
185    pub text: String,
186    pub span: (usize, usize),
187    pub script: SegmentScript,
188    pub mapping: OffsetMap,
189}
190
191impl From<TokenWithScript> for TokenDraft {
192    fn from(value: TokenWithScript) -> Self {
193        Self {
194            text: value.term,
195            span: (value.start, value.end),
196            script: value.script,
197            mapping: value.offset_map,
198        }
199    }
200}
201
202#[derive(Debug, Clone, PartialEq, Eq)]
203pub struct NormalizedTerm {
204    pub term: String,
205    pub span: (usize, usize),
206    pub script: SegmentScript,
207    pub mapping: OffsetMap,
208}
209
210#[derive(Debug, Clone, PartialEq, Eq)]
211pub struct PipelineToken {
212    pub term: String,
213    pub span: (usize, usize),
214    pub domain: TermDomain,
215    pub base_term: String,
216}
217
218pub struct TokenStream {
219    pub tokens: Vec<PipelineToken>,
220    pub term_freqs: HashMap<String, TermFrequency>,
221    pub doc_len: i64,
222}
223
224/// Snapshot of per-domain auxiliary structures.
225#[derive(Debug, Clone, Default, Serialize, Deserialize)]
226pub struct DomainSnapshot {
227    pub term_dict: HashSet<String>,
228    pub ngram_index: HashMap<String, Vec<String>>,
229}
230
231/// Persisted index state including documents and aux domain data.
232#[derive(Debug, Serialize, Deserialize)]
233pub struct SnapshotData {
234    #[serde(default)]
235    pub version: u32,
236    pub docs: HashMap<String, DocData>,
237    #[serde(default)]
238    pub domains: HashMap<TermDomain, DomainSnapshot>,
239}
240
241/// Term and domain that matched during search.
242#[derive(Debug, Clone, PartialEq, Eq, Hash)]
243pub struct MatchedTerm {
244    pub term: String,
245    pub domain: TermDomain,
246}
247
248impl MatchedTerm {
249    pub fn new(term: String, domain: TermDomain) -> Self {
250        Self { term, domain }
251    }
252}
253
/// Unit used for the offsets of returned match spans.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionEncoding {
    /// Offsets counted in raw bytes.
    Bytes,
    /// Offsets counted in UTF-16 code units, matching JavaScript/DOM string indexing.
    Utf16,
}
262
263impl Default for InMemoryIndex {
264    fn default() -> Self {
265        Self {
266            versions: HashMap::new(),
267            docs: HashMap::new(),
268            domains: HashMap::new(),
269            total_lens: HashMap::new(),
270            domain_total_lens: HashMap::new(),
271            dirty: HashMap::new(),
272            deleted: HashMap::new(),
273            position_encoding: PositionEncoding::Utf16,
274            dictionary: None,
275        }
276    }
277}
278
279/// Search hit containing the doc id, score, and matched terms/domains.
280#[derive(Debug, Clone)]
281pub struct SearchHit {
282    pub doc_id: String,
283    pub score: f64,
284    pub matched_terms: Vec<MatchedTerm>,
285}
286
287#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
288pub struct DomainLengths {
289    lens: [i64; TERM_DOMAIN_COUNT],
290}
291
292impl Default for DomainLengths {
293    fn default() -> Self {
294        Self {
295            lens: [0; TERM_DOMAIN_COUNT],
296        }
297    }
298}
299
300impl DomainLengths {
301    pub fn get(&self, domain: TermDomain) -> i64 {
302        self.lens[domain_index(domain)]
303    }
304
305    pub fn clear(&mut self) {
306        self.lens = [0; TERM_DOMAIN_COUNT];
307    }
308
309    pub fn add(&mut self, domain: TermDomain, delta: i64) {
310        let idx = domain_index(domain);
311        self.lens[idx] += delta;
312    }
313
314    pub fn is_zero(&self) -> bool {
315        self.lens.iter().all(|&v| v == 0)
316    }
317
318    pub fn for_each_nonzero(&self, mut f: impl FnMut(TermDomain, i64)) {
319        for domain in all_domains() {
320            let len = self.get(*domain);
321            if len != 0 {
322                f(*domain, len);
323            }
324        }
325    }
326
327    pub fn from_term_freqs(freqs: &HashMap<String, TermFrequency>) -> Self {
328        let mut lengths = Self::default();
329        for freqs in freqs.values() {
330            for (domain, count) in freqs.positive_domains() {
331                lengths.add(domain, count as i64);
332            }
333        }
334        lengths
335    }
336
337    pub fn from_doc(doc: &DocData) -> Self {
338        if !doc.domain_doc_len.is_zero() {
339            return doc.domain_doc_len;
340        }
341        let mut lengths = Self::from_term_freqs(&doc.term_freqs);
342        if lengths.is_zero() {
343            lengths.add(TermDomain::Original, doc.doc_len);
344        }
345        lengths
346    }
347}