// memory_indexer/base.rs

1use std::collections::{HashMap, HashSet};
2
3use super::{
4    pipeline::{DefaultTokenizer, Pipeline},
5    tokenizer::Token,
6    types::{
7        DocData, DomainLengths, DomainSnapshot, InMemoryIndex, PipelineToken, PositionEncoding,
8        SNAPSHOT_VERSION, SnapshotData, TermDomain, TokenStream,
9    },
10};
11
/// A document pending persistence: `(index_name, doc_id, content, doc_len)`.
type DirtyDoc = (String, String, String, i64);
/// Documents removed since the last flush: index_name -> set of doc ids.
type DeletedDoc = HashMap<String, HashSet<String>>;
14
15impl InMemoryIndex {
16    /// Create an index that returns match spans in the given encoding.
17    pub fn with_position_encoding(encoding: PositionEncoding) -> Self {
18        let mut index = Self::default();
19        index.position_encoding = encoding;
20        index
21    }
22
23    /// Create an index that uses a custom dictionary for tokenization.
24    pub fn with_dictionary_config(dictionary: crate::tokenizer::DictionaryConfig) -> Self {
25        let mut index = Self::default();
26        index.dictionary = Some(dictionary);
27        index
28    }
29
30    /// Set the encoding (bytes or UTF-16) used when returning match spans.
31    pub fn set_position_encoding(&mut self, encoding: PositionEncoding) {
32        self.position_encoding = encoding;
33    }
34
35    /// Swap in or remove a dictionary config for future tokenization.
36    pub fn set_dictionary_config(
37        &mut self,
38        dictionary: Option<crate::tokenizer::DictionaryConfig>,
39    ) {
40        self.dictionary = dictionary;
41    }
42
43    /// Add or replace a document in an index. Set `index` to false to stage content without
44    /// tokenization (doc will exist but not be searchable).
45    pub fn add_doc(&mut self, index_name: &str, doc_id: &str, text: &str, index: bool) {
46        let token_stream = if index {
47            self.document_pipeline().document_tokens(text)
48        } else {
49            TokenStream {
50                tokens: Vec::new(),
51                term_freqs: HashMap::new(),
52                doc_len: 0,
53            }
54        };
55
56        let mut pos_map: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
57        let mut derived_mapping: HashMap<String, HashSet<(u32, u32)>> = HashMap::new();
58        for PipelineToken {
59            term, span, domain, ..
60        } in &token_stream.tokens
61        {
62            if *domain == TermDomain::Original {
63                pos_map
64                    .entry(term.clone())
65                    .or_default()
66                    .push((span.0 as u32, span.1 as u32));
67            } else {
68                derived_mapping
69                    .entry(term.clone())
70                    .or_default()
71                    .insert((span.0 as u32, span.1 as u32));
72            }
73        }
74        let doc_len = token_stream.doc_len;
75        let term_freqs = token_stream.term_freqs;
76        let mut domain_doc_len = DomainLengths::from_term_freqs(&term_freqs);
77        if domain_doc_len.is_zero() {
78            domain_doc_len.add(TermDomain::Original, doc_len);
79        }
80
81        if let Some(docs) = self.docs.get_mut(index_name) {
82            if let Some(old_data) = docs.remove(doc_id) {
83                *self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;
84
85                let old_domain_lengths = DomainLengths::from_doc(&old_data);
86                if let Some(total_by_domain) = self.domain_total_lens.get_mut(index_name) {
87                    old_domain_lengths.for_each_nonzero(|domain, len| {
88                        total_by_domain.add(domain, -len);
89                    });
90                }
91
92                self.index_maps_mut(index_name)
93                    .remove_doc_terms(doc_id, &old_data);
94            }
95        }
96
97        let mut writer = self.index_writer(index_name, doc_id);
98        for (term, freqs) in &term_freqs {
99            writer.add_term_frequency(term, freqs);
100        }
101
102        let doc_data = DocData {
103            content: text.to_string(),
104            doc_len,
105            term_pos: pos_map,
106            term_freqs,
107            domain_doc_len: domain_doc_len.clone(),
108            derived_terms: derived_mapping
109                .into_iter()
110                .map(|(k, v)| {
111                    let mut spans: Vec<(u32, u32)> = v.into_iter().collect();
112                    spans.sort();
113                    spans.dedup();
114                    if let Some(min_len) = spans.iter().map(|(s, e)| e - s).min() {
115                        spans.retain(|(s, e)| e - s == min_len);
116                    }
117                    (k, spans)
118                })
119                .collect(),
120        };
121
122        self.docs
123            .entry(index_name.to_string())
124            .or_default()
125            .insert(doc_id.to_string(), doc_data);
126        *self.total_lens.entry(index_name.to_string()).or_default() += doc_len;
127        let total_by_domain = self
128            .domain_total_lens
129            .entry(index_name.to_string())
130            .or_default();
131        domain_doc_len.for_each_nonzero(|domain, len| {
132            total_by_domain.add(domain, len);
133        });
134
135        self.dirty
136            .entry(index_name.to_string())
137            .or_default()
138            .insert(doc_id.to_string());
139        if let Some(deleted) = self.deleted.get_mut(index_name) {
140            deleted.remove(doc_id);
141        }
142    }
143
144    /// Remove a document and its postings from an index.
145    pub fn remove_doc(&mut self, index_name: &str, doc_id: &str) {
146        if let Some(docs) = self.docs.get_mut(index_name) {
147            if let Some(old_data) = docs.remove(doc_id) {
148                *self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;
149
150                let old_domain_lengths = DomainLengths::from_doc(&old_data);
151                if let Some(total_by_domain) = self.domain_total_lens.get_mut(index_name) {
152                    old_domain_lengths.for_each_nonzero(|domain, len| {
153                        total_by_domain.add(domain, -len);
154                    });
155                }
156
157                self.index_maps_mut(index_name)
158                    .remove_doc_terms(doc_id, &old_data);
159
160                self.deleted
161                    .entry(index_name.to_string())
162                    .or_default()
163                    .insert(doc_id.to_string());
164                if let Some(dirty) = self.dirty.get_mut(index_name) {
165                    dirty.remove(doc_id);
166                }
167            }
168        }
169    }
170
171    /// Fetch raw document content by id, if present.
172    pub fn get_doc(&self, index_name: &str, doc_id: &str) -> Option<String> {
173        self.docs
174            .get(index_name)
175            .and_then(|docs| docs.get(doc_id))
176            .map(|d| d.content.clone())
177    }
178
179    /// Return and clear the sets of dirty and deleted docs for persistence.
180    pub fn take_dirty_and_deleted(&mut self) -> (Vec<DirtyDoc>, DeletedDoc) {
181        let dirty = std::mem::take(&mut self.dirty);
182        let deleted = std::mem::take(&mut self.deleted);
183
184        let mut dirty_data = Vec::new();
185        for (index_name, doc_ids) in &dirty {
186            if let Some(docs) = self.docs.get(index_name) {
187                for doc_id in doc_ids {
188                    if let Some(data) = docs.get(doc_id) {
189                        dirty_data.push((
190                            index_name.clone(),
191                            doc_id.clone(),
192                            data.content.clone(),
193                            data.doc_len,
194                        ));
195                    }
196                }
197            }
198        }
199        (dirty_data, deleted)
200    }
201
202    /// Get byte/UTF-16 spans for a query's terms within a document by re-tokenizing the query.
203    pub fn get_matches(&self, index_name: &str, doc_id: &str, query: &str) -> Vec<(u32, u32)> {
204        let query_terms: Vec<String> = self
205            .tokenize_query(query)
206            .into_iter()
207            .map(|t| t.term)
208            .collect();
209        self.get_matches_for_terms(index_name, doc_id, &query_terms)
210    }
211
212    /// Get spans for specific terms within a document.
213    pub fn get_matches_for_terms(
214        &self,
215        index_name: &str,
216        doc_id: &str,
217        terms: &[String],
218    ) -> Vec<(u32, u32)> {
219        let mut matches = Vec::new();
220        if let Some(docs) = self.docs.get(index_name) {
221            if let Some(doc_data) = docs.get(doc_id) {
222                for term in terms {
223                    if let Some(positions) = doc_data.term_pos.get(term) {
224                        matches.extend(positions.iter().cloned());
225                        continue;
226                    }
227                    if let Some(positions) = doc_data.derived_terms.get(term) {
228                        matches.extend(positions.iter().cloned());
229                    }
230                }
231                if !matches.is_empty() {
232                    matches = convert_spans(&doc_data.content, &matches, self.position_encoding);
233                }
234            }
235        }
236        matches.sort_by(|a, b| a.0.cmp(&b.0));
237        matches
238    }
239
240    /// Get spans for previously returned matched terms (e.g., from `search_hits`).
241    pub fn get_matches_for_matched_terms(
242        &self,
243        index_name: &str,
244        doc_id: &str,
245        terms: &[crate::types::MatchedTerm],
246    ) -> Vec<(u32, u32)> {
247        let term_strings: Vec<String> = terms.iter().map(|t| t.term.clone()).collect();
248        self.get_matches_for_terms(index_name, doc_id, &term_strings)
249    }
250
251    /// Load a snapshot into an index, rebuilding missing auxiliary structures if needed.
252    pub fn load_snapshot(&mut self, index_name: &str, snapshot: SnapshotData) {
253        let version = {
254            let mut maps = self.index_maps_mut(index_name);
255            maps.clear(false);
256            maps.import_snapshot(snapshot);
257            maps.version
258        };
259        self.versions.insert(index_name.to_string(), version);
260    }
261
262    /// Get a serializable snapshot of the given index, including aux dictionaries/ngrams.
263    pub fn get_snapshot_data(&self, index_name: &str) -> Option<SnapshotData> {
264        self.docs.get(index_name).map(|docs| {
265            let domains = self
266                .domains
267                .get(index_name)
268                .cloned()
269                .unwrap_or_default()
270                .into_iter()
271                .map(|(domain, data)| {
272                    (
273                        domain,
274                        DomainSnapshot {
275                            term_dict: data.term_dict,
276                            ngram_index: data.ngram_index,
277                        },
278                    )
279                })
280                .collect();
281
282            SnapshotData {
283                version: *self.versions.get(index_name).unwrap_or(&SNAPSHOT_VERSION),
284                docs: docs.clone(),
285                domains,
286            }
287        })
288    }
289
290    fn document_pipeline(&self) -> Pipeline {
291        if let Some(cfg) = &self.dictionary {
292            Pipeline::with_dictionary(cfg.clone())
293        } else {
294            Pipeline::document_pipeline()
295        }
296    }
297
298    pub(super) fn tokenize_query(&self, query: &str) -> Vec<Token> {
299        if let Some(cfg) = &self.dictionary {
300            Pipeline::new(DefaultTokenizer::for_queries().with_dictionary(cfg.clone()))
301                .query_tokens(query)
302                .tokens
303                .into_iter()
304                .map(|token| Token {
305                    term: token.term,
306                    start: token.span.0,
307                    end: token.span.1,
308                })
309                .collect()
310        } else {
311            Pipeline::tokenize_query(query)
312        }
313    }
314}
315
316fn convert_spans(
317    content: &str,
318    spans: &[(u32, u32)],
319    encoding: PositionEncoding,
320) -> Vec<(u32, u32)> {
321    match encoding {
322        PositionEncoding::Bytes => spans.to_vec(),
323        PositionEncoding::Utf16 => spans
324            .iter()
325            .map(|(start, end)| {
326                let s = to_utf16_index(content, *start as usize);
327                let e = to_utf16_index(content, *end as usize);
328                (s as u32, e as u32)
329            })
330            .collect(),
331    }
332}
333
/// Convert a byte offset in `content` to the equivalent UTF-16 code-unit offset.
///
/// Offsets past the end of `content` are clamped to its length. An offset
/// that falls inside a multi-byte character is rounded down to the previous
/// char boundary instead of panicking (the original `&content[..idx]` slice
/// would panic on a non-boundary index).
fn to_utf16_index(content: &str, byte_idx: usize) -> usize {
    let mut idx = byte_idx.min(content.len());
    if idx == 0 {
        return 0;
    }
    // Walk back to a valid char boundary so the slice below cannot panic.
    while !content.is_char_boundary(idx) {
        idx -= 1;
    }
    content[..idx].encode_utf16().count()
}