// memory_indexer/base.rs

1use std::collections::{HashMap, HashSet};
2
3use smol_str::SmolStr;
4
5use super::{
6    SNAPSHOT_VERSION,
7    index::Index,
8    pipeline::{DefaultTokenizer, Pipeline},
9    tokenizer::Token,
10    types::{
11        DerivedSpan, DerivedTerm, DocData, DomainLengths, InMemoryIndex, PositionEncoding,
12        SnapshotData, TermDomain, TermFrequencyEntry, TermId, TermPositions, TokenStream,
13    },
14};
15
/// Changed-doc payload handed to the persistence layer:
/// (index_name, doc_id, content, doc_len).
type DirtyDoc = (String, String, String, i64);
/// Deleted doc ids awaiting persistence, grouped by index name.
type DeletedDoc = HashMap<String, HashSet<String>>;
18
impl InMemoryIndex {
    /// Create an index that returns match spans in the given encoding.
    pub fn with_position_encoding(encoding: PositionEncoding) -> Self {
        Self {
            position_encoding: encoding,
            ..Default::default()
        }
    }

    /// Create an index that uses a custom dictionary for tokenization.
    pub fn with_dictionary_config(dictionary: crate::tokenizer::DictionaryConfig) -> Self {
        Self {
            dictionary: Some(dictionary),
            ..Default::default()
        }
    }

    /// Set the encoding (bytes or UTF-16) used when returning match spans.
    pub fn set_position_encoding(&mut self, encoding: PositionEncoding) {
        self.position_encoding = encoding;
    }

    /// Swap in or remove a dictionary config for future tokenization.
    ///
    /// Only affects docs tokenized after this call; existing postings are
    /// not re-tokenized.
    pub fn set_dictionary_config(
        &mut self,
        dictionary: Option<crate::tokenizer::DictionaryConfig>,
    ) {
        self.dictionary = dictionary;
    }

    /// Add or replace a document in an index. Set `index` to false to stage content without
    /// tokenization (doc will exist but not be searchable).
    pub fn add_doc(&mut self, index_name: &str, doc_id: &str, text: &str, index: bool) {
        // Staged (index == false) docs get an empty token stream: the raw
        // content is stored below, but no postings are created.
        let token_stream = if index {
            self.document_pipeline().document_tokens(text)
        } else {
            TokenStream {
                tokens: Vec::new(),
                doc_len: 0,
            }
        };

        let mut maps = Index {
            state: self.index_state_mut(index_name),
        };

        // Resolve a slot id for this doc: reuse the id of an existing doc
        // with the same key, else recycle a freed slot, else append new.
        let doc_idx = if let Some(existing) = maps.state.doc_index.get(doc_id) {
            *existing
        } else if let Some(reuse) = maps.state.free_docs.pop() {
            let doc_key = SmolStr::new(doc_id);
            if let Some(slot) = maps.state.doc_ids.get_mut(reuse as usize) {
                *slot = doc_key.clone();
            } else {
                // Recycled id points past the current vec; grow with empty
                // placeholders so the slot exists before writing it.
                maps.state
                    .doc_ids
                    .resize(reuse as usize + 1, SmolStr::default());
                maps.state.doc_ids[reuse as usize] = doc_key.clone();
            }
            if maps.state.docs.len() <= reuse as usize {
                maps.state.docs.resize(reuse as usize + 1, None);
            }
            maps.state.doc_index.insert(doc_key, reuse);
            reuse
        } else {
            let doc_key = SmolStr::new(doc_id);
            let id = maps.state.doc_ids.len() as super::types::DocId;
            maps.state.doc_ids.push(doc_key.clone());
            maps.state.docs.push(None);
            maps.state.doc_index.insert(doc_key, id);
            id
        };

        // Replacing an existing doc: back out its length contributions and
        // postings before indexing the new content.
        if let Some(old_data) = maps
            .state
            .docs
            .get_mut(doc_idx as usize)
            .and_then(|slot| slot.take())
        {
            maps.state.total_len -= old_data.doc_len;
            let old_domain_lengths = DomainLengths::from_doc(&old_data);
            old_domain_lengths.for_each_nonzero(|domain, len| {
                maps.state.domain_total_len.add(domain, -len);
            });
            maps.remove_doc_terms(doc_idx, &old_data);
        }

        // Per-term accumulators for this doc:
        //   term_pos           - byte spans of Original-domain occurrences
        //   derived_candidates - (derived, base, span) for non-Original tokens
        //   term_freqs         - per-domain occurrence counts
        let mut term_pos: HashMap<TermId, Vec<(u32, u32)>> = HashMap::new();
        let mut derived_candidates: Vec<(TermId, TermId, (u32, u32))> = Vec::new();
        let mut term_freqs: HashMap<TermId, [u32; super::types::TERM_DOMAIN_COUNT]> =
            HashMap::new();

        for token in &token_stream.tokens {
            let term_id = get_or_insert_term_id(maps.state, &token.term);
            let domain_idx = super::types::domain_index(token.domain);
            let counts = term_freqs
                .entry(term_id)
                .or_insert([0; super::types::TERM_DOMAIN_COUNT]);
            counts[domain_idx] += 1;

            if token.domain == TermDomain::Original {
                term_pos
                    .entry(term_id)
                    .or_default()
                    .push((token.span.0 as u32, token.span.1 as u32));
            } else {
                // Derived token: remember the link back to its base term so
                // matches on the derived form can resolve to real spans.
                let base_term_id = get_or_insert_term_id(maps.state, &token.base_term);
                derived_candidates.push((
                    term_id,
                    base_term_id,
                    (token.span.0 as u32, token.span.1 as u32),
                ));
            }
        }

        // Sorted, deduped position lists enable binary search at query time.
        let mut term_positions: Vec<TermPositions> = term_pos
            .into_iter()
            .map(|(term, mut positions)| {
                positions.sort();
                positions.dedup();
                TermPositions { term, positions }
            })
            .collect();
        term_positions.sort_by_key(|entry| entry.term);

        // Split derived candidates: if the base term occurs in this doc with
        // its own positions, keep a derived->base link; otherwise retain the
        // shortest observed span for the derived term directly.
        let base_terms: HashSet<TermId> = term_positions.iter().map(|entry| entry.term).collect();
        let mut derived_terms: Vec<DerivedTerm> = Vec::new();
        let mut derived_spans_map: HashMap<TermId, (u32, u32)> = HashMap::new();
        for (derived, base, span) in derived_candidates {
            if base_terms.contains(&base) {
                derived_terms.push(DerivedTerm { derived, base });
            } else {
                let span_len = span.1.saturating_sub(span.0);
                derived_spans_map
                    .entry(derived)
                    .and_modify(|existing| {
                        let existing_len = existing.1.saturating_sub(existing.0);
                        if span_len < existing_len {
                            *existing = span;
                        }
                    })
                    .or_insert(span);
            }
        }
        // Sorted + deduped so find_base_terms / find_derived_spans can binary
        // search these lists later.
        derived_terms.sort_by(|a, b| (a.derived, a.base).cmp(&(b.derived, b.base)));
        derived_terms.dedup_by(|a, b| a.derived == b.derived && a.base == b.base);
        let mut derived_spans: Vec<DerivedSpan> = derived_spans_map
            .into_iter()
            .map(|(derived, span)| DerivedSpan { derived, span })
            .collect();
        derived_spans.sort_by_key(|entry| entry.derived);

        let mut term_freqs_vec: Vec<TermFrequencyEntry> = term_freqs
            .into_iter()
            .map(|(term, counts)| TermFrequencyEntry { term, counts })
            .collect();
        term_freqs_vec.sort_by_key(|entry| entry.term);

        let doc_len = token_stream.doc_len;
        let mut domain_doc_len = DomainLengths::from_term_freqs(&term_freqs_vec);
        if domain_doc_len.is_zero() {
            // No per-domain counts at all: attribute the whole doc length to
            // the Original domain so totals stay consistent.
            domain_doc_len.add(TermDomain::Original, doc_len);
        }

        for entry in &term_freqs_vec {
            for (domain, count) in entry.positive_domains() {
                maps.add_posting(entry.term, domain, doc_idx, count);
            }
        }

        let doc_data = DocData {
            content: text.to_string(),
            doc_len,
            term_pos: term_positions,
            term_freqs: term_freqs_vec,
            domain_doc_len,
            derived_terms,
            derived_spans,
        };

        if maps.state.docs.len() <= doc_idx as usize {
            maps.state.docs.resize(doc_idx as usize + 1, None);
        }
        maps.state.docs[doc_idx as usize] = Some(doc_data);

        maps.state.total_len += doc_len;
        domain_doc_len.for_each_nonzero(|domain, len| {
            maps.state.domain_total_len.add(domain, len);
        });

        // Mark the doc dirty for the next persistence pass; a re-added doc
        // is no longer deleted.
        let doc_key = maps
            .state
            .doc_ids
            .get(doc_idx as usize)
            .cloned()
            .unwrap_or_else(|| SmolStr::new(doc_id));
        maps.state.dirty.insert(doc_key.clone());
        maps.state.deleted.remove(doc_key.as_str());
    }

    /// Remove a document and its postings from an index.
    pub fn remove_doc(&mut self, index_name: &str, doc_id: &str) {
        let mut maps = Index {
            state: self.index_state_mut(index_name),
        };
        // Unknown doc id: nothing to do.
        let Some(&doc_idx) = maps.state.doc_index.get(doc_id) else {
            return;
        };

        // Back out the doc's length contributions and postings.
        if let Some(old_data) = maps
            .state
            .docs
            .get_mut(doc_idx as usize)
            .and_then(|slot| slot.take())
        {
            maps.state.total_len -= old_data.doc_len;
            let old_domain_lengths = DomainLengths::from_doc(&old_data);
            old_domain_lengths.for_each_nonzero(|domain, len| {
                maps.state.domain_total_len.add(domain, -len);
            });
            maps.remove_doc_terms(doc_idx, &old_data);
        }

        // Release the slot for reuse and flip dirty -> deleted bookkeeping.
        // Note: doc_ids[doc_idx] keeps the stale key until the slot is reused.
        maps.state.doc_index.remove(doc_id);
        maps.state.free_docs.push(doc_idx);
        let doc_key = maps
            .state
            .doc_ids
            .get(doc_idx as usize)
            .cloned()
            .unwrap_or_else(|| SmolStr::new(doc_id));
        maps.state.deleted.insert(doc_key);
        maps.state.dirty.remove(doc_id);
    }

    /// Fetch raw document content by id, if present.
    pub fn get_doc(&self, index_name: &str, doc_id: &str) -> Option<String> {
        let state = self.indexes.get(index_name)?;
        let doc_idx = *state.doc_index.get(doc_id)? as usize;
        state
            .docs
            .get(doc_idx)
            .and_then(|doc| doc.as_ref())
            .map(|d| d.content.clone())
    }

    /// Return and clear the sets of dirty and deleted docs for persistence.
    pub fn take_dirty_and_deleted(&mut self) -> (Vec<DirtyDoc>, DeletedDoc) {
        let mut dirty_data = Vec::new();
        let mut deleted = HashMap::new();

        for (index_name, state) in self.indexes.iter_mut() {
            // mem::take clears the per-index sets in one move.
            let dirty = std::mem::take(&mut state.dirty);
            let deleted_ids = std::mem::take(&mut state.deleted);

            for doc_id in dirty {
                // Dirty ids whose doc has since vanished are silently skipped.
                if let Some(&doc_idx) = state.doc_index.get(&doc_id)
                    && let Some(doc) = state
                        .docs
                        .get(doc_idx as usize)
                        .and_then(|entry| entry.as_ref())
                {
                    dirty_data.push((
                        index_name.clone(),
                        doc_id.to_string(),
                        doc.content.clone(),
                        doc.doc_len,
                    ));
                }
            }

            if !deleted_ids.is_empty() {
                let deleted_strings: HashSet<String> = deleted_ids
                    .into_iter()
                    .map(|doc_id| doc_id.to_string())
                    .collect();
                deleted.insert(index_name.clone(), deleted_strings);
            }
        }

        (dirty_data, deleted)
    }

    /// Returns true if the index has new changes awaiting persistence.
    /// Pass `Some(name)` to check a specific index or `None` to check all.
    pub fn has_unpersisted_changes(&self, index_name: Option<&str>) -> bool {
        match index_name {
            Some(name) => self
                .indexes
                .get(name)
                .is_some_and(|state| !state.dirty.is_empty() || !state.deleted.is_empty()),
            None => self
                .indexes
                .values()
                .any(|state| !state.dirty.is_empty() || !state.deleted.is_empty()),
        }
    }

    /// Persist the given index only if there are pending changes.
    ///
    /// Returns `Ok(true)` if persistence was attempted (and succeeded), `Ok(false)` if skipped.
    /// The index is marked clean only after the provided callback returns `Ok`.
    pub fn persist_if_dirty<E>(
        &mut self,
        index_name: &str,
        mut persist: impl FnMut(SnapshotData) -> Result<(), E>,
    ) -> Result<bool, E> {
        if !self.has_unpersisted_changes(Some(index_name)) {
            return Ok(false);
        }

        // NOTE(review): get_snapshot_data returns None when every doc slot is
        // empty, so deleting the *last* doc returns Ok(false) here without
        // clearing dirty/deleted — that deletion never reaches `persist` and
        // this method keeps reporting changes. Confirm this is intended.
        let Some(snapshot) = self.get_snapshot_data(index_name) else {
            return Ok(false);
        };

        persist(snapshot)?;
        // Only mark clean after the callback succeeded.
        if let Some(state) = self.indexes.get_mut(index_name) {
            state.dirty.clear();
            state.deleted.clear();
        }
        Ok(true)
    }

    /// Get byte/UTF-16 spans for a query's terms within a document by re-tokenizing the query.
    pub fn get_matches(&self, index_name: &str, doc_id: &str, query: &str) -> Vec<(u32, u32)> {
        let query_terms: Vec<String> = self
            .tokenize_query(query)
            .into_iter()
            .map(|t| t.term)
            .collect();
        self.get_matches_for_terms(index_name, doc_id, &query_terms)
    }

    /// Get spans for specific terms within a document.
    pub fn get_matches_for_terms(
        &self,
        index_name: &str,
        doc_id: &str,
        terms: &[String],
    ) -> Vec<(u32, u32)> {
        let mut matches = Vec::new();
        let Some(state) = self.indexes.get(index_name) else {
            return matches;
        };
        let Some(&doc_idx) = state.doc_index.get(doc_id) else {
            return matches;
        };
        let Some(doc_data) = state
            .docs
            .get(doc_idx as usize)
            .and_then(|doc| doc.as_ref())
        else {
            return matches;
        };

        // Resolution order per term: direct positions, then the positions of
        // its base terms, then stored spans of the derived term itself.
        for term in terms {
            let Some(&term_id) = state.term_index.get(term.as_str()) else {
                continue;
            };

            let mut found = false;
            if let Some(positions) = find_term_positions(doc_data, term_id) {
                matches.extend(positions.iter().copied());
                found = true;
            }

            if !found {
                for base_term in find_base_terms(doc_data, term_id) {
                    if let Some(positions) = find_term_positions(doc_data, base_term) {
                        matches.extend(positions.iter().copied());
                        found = true;
                    }
                }
            }

            if !found {
                matches.extend(find_derived_spans(doc_data, term_id));
            }
        }

        if !matches.is_empty() {
            matches = convert_spans(&doc_data.content, &matches, self.position_encoding);
        }
        // Sort by start, then by span length, so equal-start groups are
        // shortest-first before pruning keeps one span per start.
        matches.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| (a.1 - a.0).cmp(&(b.1 - b.0))));
        matches = prune_overlapping_starts(&matches);
        matches
    }

    /// Get spans for previously returned matched terms (e.g., from `search_hits`).
    pub fn get_matches_for_matched_terms(
        &self,
        index_name: &str,
        doc_id: &str,
        terms: &[crate::types::MatchedTerm],
    ) -> Vec<(u32, u32)> {
        let term_strings: Vec<String> = terms.iter().map(|t| t.term.clone()).collect();
        self.get_matches_for_terms(index_name, doc_id, &term_strings)
    }

    /// Load a snapshot into an index, expecting all auxiliary structures to be present.
    pub fn load_snapshot(&mut self, index_name: &str, snapshot: SnapshotData) {
        // Incompatible snapshots are ignored silently.
        if snapshot.version != SNAPSHOT_VERSION {
            return;
        }
        let version = {
            let mut maps = Index {
                state: self.index_state_mut(index_name),
            };
            maps.clear();
            maps.import_snapshot(snapshot);
            maps.state.version
        };
        // NOTE(review): `version` was just read from the same state, so this
        // re-assignment looks redundant — confirm import_snapshot semantics.
        if let Some(state) = self.indexes.get_mut(index_name) {
            state.version = version;
            state.dirty.clear();
            state.deleted.clear();
        }
    }

    /// Get a serializable snapshot of the given index, including postings.
    pub fn get_snapshot_data(&self, index_name: &str) -> Option<SnapshotData> {
        let state = self.indexes.get(index_name)?;
        // An index with no live docs yields no snapshot.
        if state.docs.iter().all(|d| d.is_none()) {
            return None;
        }

        Some(SnapshotData {
            version: state.version,
            terms: state.terms.clone(),
            docs: state.docs.clone(),
            doc_ids: state.doc_ids.clone(),
            domains: state.domains.clone(),
            total_len: state.total_len,
            domain_total_len: state.domain_total_len,
        })
    }

    // Build the tokenization pipeline for documents, honoring a custom
    // dictionary when one is configured.
    fn document_pipeline(&self) -> Pipeline {
        if let Some(cfg) = &self.dictionary {
            Pipeline::with_dictionary(cfg.clone())
        } else {
            Pipeline::document_pipeline()
        }
    }

    // Tokenize a query string with the same dictionary the documents use, so
    // query terms line up with indexed terms.
    pub(super) fn tokenize_query(&self, query: &str) -> Vec<Token> {
        if let Some(cfg) = &self.dictionary {
            Pipeline::new(DefaultTokenizer::for_queries().with_dictionary(cfg.clone()))
                .query_tokens(query)
                .tokens
                .into_iter()
                .map(|token| Token {
                    term: token.term,
                    start: token.span.0,
                    end: token.span.1,
                })
                .collect()
        } else {
            Pipeline::tokenize_query(query)
        }
    }
}
480
481fn get_or_insert_term_id(state: &mut super::types::IndexState, term: &str) -> TermId {
482    if let Some(&id) = state.term_index.get(term) {
483        return id;
484    }
485    let id = state.terms.len() as TermId;
486    let term_key = SmolStr::new(term);
487    state.terms.push(term_key.clone());
488    state.term_index.insert(term_key, id);
489    id
490}
491
492fn find_term_positions(doc: &DocData, term: TermId) -> Option<&[(u32, u32)]> {
493    let idx = doc
494        .term_pos
495        .binary_search_by_key(&term, |entry| entry.term)
496        .ok()?;
497    Some(&doc.term_pos[idx].positions)
498}
499
500fn find_base_terms(doc: &DocData, derived: TermId) -> Vec<TermId> {
501    let list = &doc.derived_terms;
502    let mut start = match list.binary_search_by_key(&derived, |entry| entry.derived) {
503        Ok(idx) => idx,
504        Err(_) => return Vec::new(),
505    };
506    while start > 0 && list[start - 1].derived == derived {
507        start -= 1;
508    }
509    let mut terms = Vec::new();
510    let mut idx = start;
511    while idx < list.len() && list[idx].derived == derived {
512        terms.push(list[idx].base);
513        idx += 1;
514    }
515    terms
516}
517
518fn find_derived_spans(doc: &DocData, derived: TermId) -> Vec<(u32, u32)> {
519    let list = &doc.derived_spans;
520    let mut start = match list.binary_search_by_key(&derived, |entry| entry.derived) {
521        Ok(idx) => idx,
522        Err(_) => return Vec::new(),
523    };
524    while start > 0 && list[start - 1].derived == derived {
525        start -= 1;
526    }
527    let mut spans = Vec::new();
528    let mut idx = start;
529    while idx < list.len() && list[idx].derived == derived {
530        spans.push(list[idx].span);
531        idx += 1;
532    }
533    spans
534}
535
536fn convert_spans(
537    content: &str,
538    spans: &[(u32, u32)],
539    encoding: PositionEncoding,
540) -> Vec<(u32, u32)> {
541    match encoding {
542        PositionEncoding::Bytes => spans.to_vec(),
543        PositionEncoding::Utf16 => spans
544            .iter()
545            .map(|(start, end)| {
546                let s = to_utf16_index(content, *start as usize);
547                let e = to_utf16_index(content, *end as usize);
548                (s as u32, e as u32)
549            })
550            .collect(),
551    }
552}
553
/// Convert a byte offset in `content` into a UTF-16 code-unit offset.
///
/// The offset is clamped to the string length, and — fix — snapped down to
/// the nearest char boundary: the original sliced `&content[..byte_idx]`
/// directly, which panics when `byte_idx` falls inside a multi-byte
/// character.
fn to_utf16_index(content: &str, byte_idx: usize) -> usize {
    let mut end = byte_idx.min(content.len());
    // Walk back to a valid boundary so slicing cannot panic; 0 is always a
    // boundary, so this loop terminates.
    while !content.is_char_boundary(end) {
        end -= 1;
    }
    content[..end].encode_utf16().count()
}
561
/// Collapse runs of spans sharing a start offset down to the shortest span.
///
/// Expects `spans` sorted by start offset (as `get_matches_for_terms`
/// produces); only consecutive equal starts are grouped.
fn prune_overlapping_starts(spans: &[(u32, u32)]) -> Vec<(u32, u32)> {
    let mut kept: Vec<(u32, u32)> = Vec::new();
    for &span in spans {
        match kept.last_mut() {
            // Same start as the previous survivor: keep whichever is shorter.
            Some(prev) if prev.0 == span.0 => {
                if span.1 - span.0 < prev.1 - prev.0 {
                    *prev = span;
                }
            }
            // New start offset: this span starts a fresh group.
            _ => kept.push(span),
        }
    }
    kept
}