Skip to main content

memory_indexer/
types.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::RwLock;
3
4use serde::{Deserialize, Serialize};
5use smol_str::SmolStr;
6
7use super::tokenizer::{DictionaryConfig, OffsetMap, SegmentScript, TokenWithScript};
8
/// Current snapshot format version, also returned by `InMemoryIndex::snapshot_version`.
pub const SNAPSHOT_VERSION: u32 = 5;

/// Numeric id of a term; `IndexState::term_index` maps term text to it.
pub type TermId = u32;
/// Numeric id of a document slot; `IndexState::doc_index` maps external document ids to it.
pub type DocId = u32;
13
/// Search execution strategy for a query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
    /// Only search original terms.
    Exact,
    /// Search the derived pinyin domains.
    Pinyin,
    /// Allow fuzzy matching for tolerant recall.
    Fuzzy,
    /// Try exact first, then fall back to pinyin and fuzzy matching.
    Auto,
}
26
27/// Token domain representing how a term was derived or transformed.
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
29pub enum TermDomain {
30    Original,
31    PinyinFull,
32    PinyinInitials,
33}
34
35pub(crate) const TERM_DOMAIN_COUNT: usize = 3;
36
37pub(crate) const fn domain_index(domain: TermDomain) -> usize {
38    match domain {
39        TermDomain::Original => 0,
40        TermDomain::PinyinFull => 1,
41        TermDomain::PinyinInitials => 2,
42    }
43}
44
45impl TermDomain {
46    /// Returns true if the domain represents a pinyin-derived token.
47    pub fn is_pinyin(&self) -> bool {
48        matches!(self, TermDomain::PinyinFull | TermDomain::PinyinInitials)
49    }
50}
51
52#[derive(Debug, Clone, Copy)]
53pub struct DomainConfig {
54    pub weight: f64,
55    pub allow_fuzzy: bool,
56}
57
58pub fn domain_config(domain: TermDomain) -> DomainConfig {
59    match domain {
60        TermDomain::Original => DomainConfig {
61            weight: 1.0,
62            allow_fuzzy: true,
63        },
64        TermDomain::PinyinFull => DomainConfig {
65            weight: 0.9,
66            allow_fuzzy: true,
67        },
68        TermDomain::PinyinInitials => DomainConfig {
69            weight: 0.8,
70            allow_fuzzy: true,
71        },
72    }
73}
74
75pub fn all_domains() -> &'static [TermDomain] {
76    &[
77        TermDomain::Original,
78        TermDomain::PinyinFull,
79        TermDomain::PinyinInitials,
80    ]
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct Posting {
85    pub doc: DocId,
86    pub freq: u32,
87}
88
89#[derive(Debug, Default, Clone)]
90pub struct DomainAux {
91    pub term_ids: Option<Vec<TermId>>,
92    pub ngram_index: Option<HashMap<SmolStr, Vec<TermId>>>,
93    pub prefix_index: Option<HashMap<SmolStr, Vec<TermId>>>,
94}
95
96impl DomainAux {
97    pub fn clear(&mut self) {
98        *self = Self::default();
99    }
100
101    fn default_lock() -> RwLock<Self> {
102        RwLock::new(Self::default())
103    }
104}
105
106#[derive(Debug, Serialize, Deserialize)]
107pub struct DomainIndex {
108    pub postings: HashMap<TermId, Vec<Posting>>,
109    #[serde(skip, default = "DomainAux::default_lock")]
110    pub aux: RwLock<DomainAux>,
111}
112
113impl Clone for DomainIndex {
114    fn clone(&self) -> Self {
115        Self {
116            postings: self.postings.clone(),
117            aux: RwLock::new(DomainAux::default()),
118        }
119    }
120}
121
122impl Default for DomainIndex {
123    fn default() -> Self {
124        Self {
125            postings: HashMap::new(),
126            aux: RwLock::new(DomainAux::default()),
127        }
128    }
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct TermPositions {
133    pub term: TermId,
134    pub positions: Vec<(u32, u32)>,
135}
136
137#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
138pub struct TermFrequencyEntry {
139    pub term: TermId,
140    pub counts: [u32; TERM_DOMAIN_COUNT],
141}
142
143impl TermFrequencyEntry {
144    pub fn positive_domains(&self) -> Vec<(TermDomain, u32)> {
145        let mut domains = Vec::new();
146        for domain in all_domains() {
147            let count = self.counts[domain_index(*domain)];
148            if count > 0 {
149                domains.push((*domain, count));
150            }
151        }
152        domains
153    }
154}
155
156#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
157pub struct DerivedTerm {
158    pub derived: TermId,
159    pub base: TermId,
160}
161
162#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
163pub struct DerivedSpan {
164    pub derived: TermId,
165    pub span: (u32, u32),
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct DocData {
170    pub content: String,
171    /// Document length in normalized tokens.
172    pub doc_len: i64,
173    /// Term positions for original-domain tokens (by term id).
174    pub term_pos: Vec<TermPositions>,
175    /// Per-term frequencies for removal and length accounting.
176    pub term_freqs: Vec<TermFrequencyEntry>,
177    #[serde(default)]
178    pub domain_doc_len: DomainLengths,
179    /// Mapping from derived term id -> base term id for highlighting.
180    #[serde(default)]
181    pub derived_terms: Vec<DerivedTerm>,
182    /// Fallback spans for derived terms that lack base term positions.
183    #[serde(default)]
184    pub derived_spans: Vec<DerivedSpan>,
185}
186
187/// In-memory inverted index supporting exact, pinyin, and fuzzy search over documents.
188#[derive(Debug)]
189pub struct InMemoryIndex {
190    pub indexes: HashMap<String, IndexState>,
191    pub position_encoding: PositionEncoding,
192    pub dictionary: Option<DictionaryConfig>,
193}
194
195#[derive(Debug, Default)]
196pub struct IndexState {
197    pub version: u32,
198    pub terms: Vec<SmolStr>,
199    pub term_index: HashMap<SmolStr, TermId>,
200    pub docs: Vec<Option<DocData>>,
201    pub doc_ids: Vec<SmolStr>,
202    pub doc_index: HashMap<SmolStr, DocId>,
203    pub free_docs: Vec<DocId>,
204    pub domains: HashMap<TermDomain, DomainIndex>,
205    pub total_len: i64,
206    pub domain_total_len: DomainLengths,
207    pub dirty: HashSet<SmolStr>,
208    pub deleted: HashSet<SmolStr>,
209}
210
211impl InMemoryIndex {
212    pub fn snapshot_version() -> u32 {
213        SNAPSHOT_VERSION
214    }
215}
216
217#[derive(Debug, Clone, Copy)]
218pub struct Segment<'a> {
219    pub script: SegmentScript,
220    pub text: &'a str,
221    pub offset: usize,
222}
223
224#[derive(Debug, Clone, PartialEq, Eq)]
225pub struct TokenDraft {
226    pub text: String,
227    pub span: (usize, usize),
228    pub script: SegmentScript,
229    pub mapping: OffsetMap,
230}
231
232impl From<TokenWithScript> for TokenDraft {
233    fn from(value: TokenWithScript) -> Self {
234        Self {
235            text: value.term,
236            span: (value.start, value.end),
237            script: value.script,
238            mapping: value.offset_map,
239        }
240    }
241}
242
243#[derive(Debug, Clone, PartialEq, Eq)]
244pub struct NormalizedTerm {
245    pub term: String,
246    pub span: (usize, usize),
247    pub script: SegmentScript,
248    pub mapping: OffsetMap,
249}
250
251#[derive(Debug, Clone, PartialEq, Eq)]
252pub struct PipelineToken {
253    pub term: String,
254    pub span: (usize, usize),
255    pub domain: TermDomain,
256    pub base_term: String,
257}
258
259pub struct TokenStream {
260    pub tokens: Vec<PipelineToken>,
261    pub doc_len: i64,
262}
263
264/// Persisted index state including documents and per-domain postings.
265#[derive(Debug, Serialize, Deserialize)]
266pub struct SnapshotData {
267    #[serde(default)]
268    pub version: u32,
269    pub terms: Vec<SmolStr>,
270    pub docs: Vec<Option<DocData>>,
271    pub doc_ids: Vec<SmolStr>,
272    pub domains: HashMap<TermDomain, DomainIndex>,
273    pub total_len: i64,
274    pub domain_total_len: DomainLengths,
275}
276
277/// Term and domain that matched during search.
278#[derive(Debug, Clone, PartialEq, Eq, Hash)]
279pub struct MatchedTerm {
280    pub term: String,
281    pub domain: TermDomain,
282}
283
284impl MatchedTerm {
285    pub fn new(term: String, domain: TermDomain) -> Self {
286        Self { term, domain }
287    }
288}
289
/// Unit in which match spans are returned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionEncoding {
    /// Return offsets in raw bytes.
    Bytes,
    /// Return offsets in UTF-16 code units (useful for JS/DOM).
    Utf16,
}
298
299impl Default for InMemoryIndex {
300    fn default() -> Self {
301        Self {
302            indexes: HashMap::new(),
303            position_encoding: PositionEncoding::Utf16,
304            dictionary: None,
305        }
306    }
307}
308
309/// Search hit containing the doc id, score, and matched terms/domains.
310#[derive(Debug, Clone)]
311pub struct SearchHit {
312    pub doc_id: String,
313    pub score: f64,
314    pub matched_terms: Vec<MatchedTerm>,
315}
316
317#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
318pub struct DomainLengths {
319    lens: [i64; TERM_DOMAIN_COUNT],
320}
321
322impl Default for DomainLengths {
323    fn default() -> Self {
324        Self {
325            lens: [0; TERM_DOMAIN_COUNT],
326        }
327    }
328}
329
330impl DomainLengths {
331    pub fn get(&self, domain: TermDomain) -> i64 {
332        self.lens[domain_index(domain)]
333    }
334
335    pub fn clear(&mut self) {
336        self.lens = [0; TERM_DOMAIN_COUNT];
337    }
338
339    pub fn add(&mut self, domain: TermDomain, delta: i64) {
340        let idx = domain_index(domain);
341        self.lens[idx] += delta;
342    }
343
344    pub fn is_zero(&self) -> bool {
345        self.lens.iter().all(|&v| v == 0)
346    }
347
348    pub fn for_each_nonzero(&self, mut f: impl FnMut(TermDomain, i64)) {
349        for domain in all_domains() {
350            let len = self.get(*domain);
351            if len != 0 {
352                f(*domain, len);
353            }
354        }
355    }
356
357    pub fn from_term_freqs(freqs: &[TermFrequencyEntry]) -> Self {
358        let mut lengths = Self::default();
359        for entry in freqs {
360            for (domain, count) in entry.positive_domains() {
361                lengths.add(domain, count as i64);
362            }
363        }
364        lengths
365    }
366
367    pub fn from_doc(doc: &DocData) -> Self {
368        if !doc.domain_doc_len.is_zero() {
369            return doc.domain_doc_len;
370        }
371        let mut lengths = Self::from_term_freqs(&doc.term_freqs);
372        if lengths.is_zero() {
373            lengths.add(TermDomain::Original, doc.doc_len);
374        }
375        lengths
376    }
377}