// memory_indexer/base.rs

1use std::collections::{HashMap, HashSet};
2
3use super::{
4    pipeline::{DefaultTokenizer, Pipeline},
5    tokenizer::Token,
6    types::{
7        DocData, DomainLengths, InMemoryIndex, PipelineToken, PositionEncoding, SNAPSHOT_VERSION,
8        SnapshotData, TermDomain, TokenStream,
9    },
10};
11
/// A document staged for persistence: (index_name, doc_id, content, doc_len).
type DirtyDoc = (String, String, String, i64);
/// Doc ids pending deletion from persistent storage, keyed by index name.
type DeletedDoc = HashMap<String, HashSet<String>>;
14
15impl InMemoryIndex {
16    /// Create an index that returns match spans in the given encoding.
17    pub fn with_position_encoding(encoding: PositionEncoding) -> Self {
18        let mut index = Self::default();
19        index.position_encoding = encoding;
20        index
21    }
22
23    /// Create an index that uses a custom dictionary for tokenization.
24    pub fn with_dictionary_config(dictionary: crate::tokenizer::DictionaryConfig) -> Self {
25        let mut index = Self::default();
26        index.dictionary = Some(dictionary);
27        index
28    }
29
30    /// Set the encoding (bytes or UTF-16) used when returning match spans.
31    pub fn set_position_encoding(&mut self, encoding: PositionEncoding) {
32        self.position_encoding = encoding;
33    }
34
35    /// Swap in or remove a dictionary config for future tokenization.
36    pub fn set_dictionary_config(
37        &mut self,
38        dictionary: Option<crate::tokenizer::DictionaryConfig>,
39    ) {
40        self.dictionary = dictionary;
41    }
42
43    /// Add or replace a document in an index. Set `index` to false to stage content without
44    /// tokenization (doc will exist but not be searchable).
45    pub fn add_doc(&mut self, index_name: &str, doc_id: &str, text: &str, index: bool) {
46        let token_stream = if index {
47            self.document_pipeline().document_tokens(text)
48        } else {
49            TokenStream {
50                tokens: Vec::new(),
51                term_freqs: HashMap::new(),
52                doc_len: 0,
53            }
54        };
55
56        let mut pos_map: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
57        let mut derived_mapping: HashMap<String, HashSet<(u32, u32)>> = HashMap::new();
58        for PipelineToken {
59            term, span, domain, ..
60        } in &token_stream.tokens
61        {
62            if *domain == TermDomain::Original {
63                pos_map
64                    .entry(term.clone())
65                    .or_default()
66                    .push((span.0 as u32, span.1 as u32));
67            } else {
68                derived_mapping
69                    .entry(term.clone())
70                    .or_default()
71                    .insert((span.0 as u32, span.1 as u32));
72            }
73        }
74        let doc_len = token_stream.doc_len;
75        let term_freqs = token_stream.term_freqs;
76        let mut domain_doc_len = DomainLengths::from_term_freqs(&term_freqs);
77        if domain_doc_len.is_zero() {
78            domain_doc_len.add(TermDomain::Original, doc_len);
79        }
80
81        if let Some(docs) = self.docs.get_mut(index_name) {
82            if let Some(old_data) = docs.remove(doc_id) {
83                *self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;
84
85                let old_domain_lengths = DomainLengths::from_doc(&old_data);
86                if let Some(total_by_domain) = self.domain_total_lens.get_mut(index_name) {
87                    old_domain_lengths.for_each_nonzero(|domain, len| {
88                        total_by_domain.add(domain, -len);
89                    });
90                }
91
92                self.index_maps_mut(index_name)
93                    .remove_doc_terms(doc_id, &old_data);
94            }
95        }
96
97        let mut writer = self.index_writer(index_name, doc_id);
98        for (term, freqs) in &term_freqs {
99            writer.add_term_frequency(term, freqs);
100        }
101
102        let doc_data = DocData {
103            content: text.to_string(),
104            doc_len,
105            term_pos: pos_map,
106            term_freqs,
107            domain_doc_len: domain_doc_len.clone(),
108            derived_terms: derived_mapping
109                .into_iter()
110                .map(|(k, v)| {
111                    let mut spans: Vec<(u32, u32)> = v.into_iter().collect();
112                    spans.sort();
113                    spans.dedup();
114                    if let Some(min_len) = spans.iter().map(|(s, e)| e - s).min() {
115                        spans.retain(|(s, e)| e - s == min_len);
116                    }
117                    (k, spans)
118                })
119                .collect(),
120        };
121
122        self.docs
123            .entry(index_name.to_string())
124            .or_default()
125            .insert(doc_id.to_string(), doc_data);
126        *self.total_lens.entry(index_name.to_string()).or_default() += doc_len;
127        let total_by_domain = self
128            .domain_total_lens
129            .entry(index_name.to_string())
130            .or_default();
131        domain_doc_len.for_each_nonzero(|domain, len| {
132            total_by_domain.add(domain, len);
133        });
134
135        self.dirty
136            .entry(index_name.to_string())
137            .or_default()
138            .insert(doc_id.to_string());
139        if let Some(deleted) = self.deleted.get_mut(index_name) {
140            deleted.remove(doc_id);
141        }
142    }
143
144    /// Remove a document and its postings from an index.
145    pub fn remove_doc(&mut self, index_name: &str, doc_id: &str) {
146        if let Some(docs) = self.docs.get_mut(index_name) {
147            if let Some(old_data) = docs.remove(doc_id) {
148                *self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;
149
150                let old_domain_lengths = DomainLengths::from_doc(&old_data);
151                if let Some(total_by_domain) = self.domain_total_lens.get_mut(index_name) {
152                    old_domain_lengths.for_each_nonzero(|domain, len| {
153                        total_by_domain.add(domain, -len);
154                    });
155                }
156
157                self.index_maps_mut(index_name)
158                    .remove_doc_terms(doc_id, &old_data);
159
160                self.deleted
161                    .entry(index_name.to_string())
162                    .or_default()
163                    .insert(doc_id.to_string());
164                if let Some(dirty) = self.dirty.get_mut(index_name) {
165                    dirty.remove(doc_id);
166                }
167            }
168        }
169    }
170
171    /// Fetch raw document content by id, if present.
172    pub fn get_doc(&self, index_name: &str, doc_id: &str) -> Option<String> {
173        self.docs
174            .get(index_name)
175            .and_then(|docs| docs.get(doc_id))
176            .map(|d| d.content.clone())
177    }
178
179    /// Return and clear the sets of dirty and deleted docs for persistence.
180    pub fn take_dirty_and_deleted(&mut self) -> (Vec<DirtyDoc>, DeletedDoc) {
181        let dirty = std::mem::take(&mut self.dirty);
182        let deleted = std::mem::take(&mut self.deleted);
183
184        let mut dirty_data = Vec::new();
185        for (index_name, doc_ids) in &dirty {
186            if let Some(docs) = self.docs.get(index_name) {
187                for doc_id in doc_ids {
188                    if let Some(data) = docs.get(doc_id) {
189                        dirty_data.push((
190                            index_name.clone(),
191                            doc_id.clone(),
192                            data.content.clone(),
193                            data.doc_len,
194                        ));
195                    }
196                }
197            }
198        }
199        (dirty_data, deleted)
200    }
201
202    /// Get byte/UTF-16 spans for a query's terms within a document by re-tokenizing the query.
203    pub fn get_matches(&self, index_name: &str, doc_id: &str, query: &str) -> Vec<(u32, u32)> {
204        let query_terms: Vec<String> = self
205            .tokenize_query(query)
206            .into_iter()
207            .map(|t| t.term)
208            .collect();
209        self.get_matches_for_terms(index_name, doc_id, &query_terms)
210    }
211
212    /// Get spans for specific terms within a document.
213    pub fn get_matches_for_terms(
214        &self,
215        index_name: &str,
216        doc_id: &str,
217        terms: &[String],
218    ) -> Vec<(u32, u32)> {
219        let mut matches = Vec::new();
220        if let Some(docs) = self.docs.get(index_name) {
221            if let Some(doc_data) = docs.get(doc_id) {
222                for term in terms {
223                    if let Some(positions) = doc_data.term_pos.get(term) {
224                        matches.extend(positions.iter().cloned());
225                        continue;
226                    }
227                    if let Some(positions) = doc_data.derived_terms.get(term) {
228                        matches.extend(positions.iter().cloned());
229                    }
230                }
231                if !matches.is_empty() {
232                    matches = convert_spans(&doc_data.content, &matches, self.position_encoding);
233                }
234            }
235        }
236        matches.sort_by(|a, b| a.0.cmp(&b.0));
237        matches
238    }
239
240    /// Get spans for previously returned matched terms (e.g., from `search_hits`).
241    pub fn get_matches_for_matched_terms(
242        &self,
243        index_name: &str,
244        doc_id: &str,
245        terms: &[crate::types::MatchedTerm],
246    ) -> Vec<(u32, u32)> {
247        let term_strings: Vec<String> = terms.iter().map(|t| t.term.clone()).collect();
248        self.get_matches_for_terms(index_name, doc_id, &term_strings)
249    }
250
251    /// Load a snapshot into an index, expecting all auxiliary structures to be present.
252    pub fn load_snapshot(&mut self, index_name: &str, snapshot: SnapshotData) {
253        assert_eq!(
254            snapshot.version, SNAPSHOT_VERSION,
255            "snapshot version {} does not match expected {}",
256            snapshot.version, SNAPSHOT_VERSION
257        );
258        let version = {
259            let mut maps = self.index_maps_mut(index_name);
260            maps.clear(false);
261            maps.import_snapshot(snapshot);
262            maps.version
263        };
264        self.versions.insert(index_name.to_string(), version);
265    }
266
267    /// Get a serializable snapshot of the given index, including aux dictionaries/ngrams.
268    pub fn get_snapshot_data(&self, index_name: &str) -> Option<SnapshotData> {
269        self.docs.get(index_name).map(|docs| {
270            let domains = self.domains.get(index_name).cloned().unwrap_or_default();
271
272            SnapshotData {
273                version: *self.versions.get(index_name).unwrap_or(&SNAPSHOT_VERSION),
274                docs: docs.clone(),
275                total_len: *self.total_lens.get(index_name).unwrap_or(&0),
276                domain_total_len: self
277                    .domain_total_lens
278                    .get(index_name)
279                    .cloned()
280                    .unwrap_or_default(),
281                domains,
282            }
283        })
284    }
285
286    fn document_pipeline(&self) -> Pipeline {
287        if let Some(cfg) = &self.dictionary {
288            Pipeline::with_dictionary(cfg.clone())
289        } else {
290            Pipeline::document_pipeline()
291        }
292    }
293
294    pub(super) fn tokenize_query(&self, query: &str) -> Vec<Token> {
295        if let Some(cfg) = &self.dictionary {
296            Pipeline::new(DefaultTokenizer::for_queries().with_dictionary(cfg.clone()))
297                .query_tokens(query)
298                .tokens
299                .into_iter()
300                .map(|token| Token {
301                    term: token.term,
302                    start: token.span.0,
303                    end: token.span.1,
304                })
305                .collect()
306        } else {
307            Pipeline::tokenize_query(query)
308        }
309    }
310}
311
312fn convert_spans(
313    content: &str,
314    spans: &[(u32, u32)],
315    encoding: PositionEncoding,
316) -> Vec<(u32, u32)> {
317    match encoding {
318        PositionEncoding::Bytes => spans.to_vec(),
319        PositionEncoding::Utf16 => spans
320            .iter()
321            .map(|(start, end)| {
322                let s = to_utf16_index(content, *start as usize);
323                let e = to_utf16_index(content, *end as usize);
324                (s as u32, e as u32)
325            })
326            .collect(),
327    }
328}
329
/// Convert a byte offset in `content` to the equivalent UTF-16 code-unit offset.
///
/// Offsets past the end of `content` clamp to the string's total UTF-16
/// length. An offset that falls inside a multi-byte character — which would
/// panic if used to slice the string — counts the straddled character instead
/// of panicking.
fn to_utf16_index(content: &str, byte_idx: usize) -> usize {
    content
        .char_indices()
        .take_while(|&(i, _)| i < byte_idx)
        .map(|(_, c)| c.len_utf16())
        .sum()
}