Skip to main content

mcp_memory/
search.rs

1use crate::intern::StrId;
2
3/// Inverted word index for fast entity search.
4///
5/// For each entity, we tokenize its name, type, and observations,
6/// store each token → set of matching entity indices.
7///
8/// Uses a flat `Vec<(StrId, u32)>` sorted by (token, entity_idx)
9/// for cache-friendly lookups via binary search.
10pub struct SearchIndex {
11    // Sorted by (token, entity_idx), no duplicates.
12    entries: Vec<(StrId, u32)>,
13    lower_buf: Vec<u8>,
14}
15
16impl SearchIndex {
17    pub fn new() -> Self {
18        Self {
19            entries: Vec::new(),
20            lower_buf: Vec::with_capacity(256),
21        }
22    }
23
24    pub fn clear(&mut self) {
25        self.entries.clear();
26    }
27
28    pub const fn len(&self) -> usize {
29        self.entries.len()
30    }
31
32    pub const fn is_empty(&self) -> bool {
33        self.entries.is_empty()
34    }
35
36    /// Index a single entity by its name, type, and observations.
37    /// All strings must already be interned.
38    /// `entity_idx` is the position in the entity storage vec.
39    pub fn index_entity(
40        &mut self,
41        interner: &mut crate::intern::StringInterner,
42        entity_idx: u32,
43        name: StrId,
44        entity_type: StrId,
45        observations: &[StrId],
46    ) {
47        let mut texts = Vec::with_capacity(2 + observations.len());
48        texts.push(name);
49        texts.push(entity_type);
50        texts.extend_from_slice(observations);
51        self.insert_tokens(interner, entity_idx, &texts);
52    }
53
54    /// Incrementally index additional strings (e.g. newly added observations)
55    /// for an entity that is *already* indexed, without removing and rebuilding
56    /// its existing entries (P3). Token entries that already exist are deduped
57    /// during the merge, so calling this with text that overlaps existing
58    /// tokens is safe.
59    pub fn index_additional(
60        &mut self,
61        interner: &mut crate::intern::StringInterner,
62        entity_idx: u32,
63        texts: &[StrId],
64    ) {
65        self.insert_tokens(interner, entity_idx, texts);
66    }
67
68    /// Remove all entries for a given entity (before re-indexing).
69    pub fn remove_entity(&mut self, entity_idx: u32) {
70        self.entries.retain(|&(_, idx)| idx != entity_idx);
71    }
72
73    /// Search for entities whose name/type/observation tokens match `query`
74    /// case-insensitively by **prefix** (`"cof"` matches `"coffee"`).
75    ///
76    /// Note: this is an O(n) scan over every index entry. The binary-search
77    /// step below only narrows exact-token hits, but the subsequent prefix scan
78    /// already covers those (an exact match is also a prefix match), so the scan
79    /// dominates — do not read the binary search as making this sublinear.
80    pub fn search(&self, query: &str, interner: &crate::intern::StringInterner) -> Vec<u32> {
81        if query.is_empty() || self.entries.is_empty() {
82            return Vec::new();
83        }
84
85        let lower_query: String = query.to_ascii_lowercase();
86
87        // Fast path: exact token match via binary search
88        let mut matched = if let Some(token_id) = interner.get_optional(&lower_query) {
89            let range_begin = self.entries.binary_search_by(|(t, _)| t.cmp(&token_id));
90            let range_end = self.entries.binary_search_by(|(t, _)| {
91                if *t <= token_id { std::cmp::Ordering::Less } else { std::cmp::Ordering::Greater }
92            });
93            if let (Ok(begin), Err(end)) = (range_begin, range_end) {
94                self.entries[begin..end].iter().map(|&(_, idx)| idx).collect()
95            } else {
96                Vec::new()
97            }
98        } else {
99            Vec::new()
100        };
101
102        // Prefix match scan for tokens starting with the query
103        for &(token_id, entity_idx) in &self.entries {
104            if matched.last().is_none_or(|&last| last != entity_idx) {
105                let token = interner.lookup(token_id);
106                if token.len() >= lower_query.len()
107                    && token.as_bytes().starts_with(lower_query.as_bytes())
108                {
109                    matched.push(entity_idx);
110                }
111            }
112        }
113
114        matched.sort_unstable();
115        matched.dedup();
116        matched
117    }
118
119    /// Like [`search`], but returns `(entity_idx, score)` pairs sorted by
120    /// descending score (then ascending idx for stability). `score` is the
121    /// number of indexed-token hits the entity accumulated for the query —
122    /// a cheap relevance proxy so callers can surface the best matches first.
123    ///
124    /// The scan is a single linear pass over the flat `entries` vec (no
125    /// per-entity allocation until the final compaction), keeping it
126    /// cache-friendly. A small `Vec<(idx, score)>` is gathered then sorted.
127    pub fn search_ranked(&self, query: &str, interner: &crate::intern::StringInterner) -> Vec<(u32, u32)> {
128        if query.is_empty() || self.entries.is_empty() {
129            return Vec::new();
130        }
131
132        let lower_query: String = query.to_ascii_lowercase();
133        let qbytes = lower_query.as_bytes();
134        let qlen = qbytes.len();
135
136        // Exact-token id (if the query is itself an interned token) lets us
137        // score exact hits without a string compare.
138        let exact_id = interner.get_optional(&lower_query);
139
140        // (idx, score) gathered in one pass, idx-major so equal idxs are adjacent.
141        let mut hits: Vec<(u32, u32)> = Vec::new();
142        for &(token_id, entity_idx) in &self.entries {
143            let matches = if Some(token_id) == exact_id {
144                true
145            } else {
146                let token = interner.lookup(token_id);
147                token.len() >= qlen && token.as_bytes().starts_with(qbytes)
148            };
149            if matches {
150                match hits.last_mut() {
151                    Some(last) if last.0 == entity_idx => last.1 += 1,
152                    _ => hits.push((entity_idx, 1)),
153                }
154            }
155        }
156
157        // entries are sorted by (token, idx), so a single entity_idx may appear
158        // in non-adjacent groups (once per matching token). Merge by idx, then
159        // rank by score desc.
160        hits.sort_unstable_by_key(|&(idx, _)| idx);
161        let mut merged: Vec<(u32, u32)> = Vec::with_capacity(hits.len());
162        for (idx, score) in hits {
163            match merged.last_mut() {
164                Some(last) if last.0 == idx => last.1 += score,
165                _ => merged.push((idx, score)),
166            }
167        }
168        merged.sort_unstable_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
169        merged
170    }
171
172    /// Tokenize every string in `texts`, collect the resulting
173    /// `(token, entity_idx)` entries, and merge them into the sorted `entries`
174    /// vec in a single O(N + K) pass (P2). This replaces the previous
175    /// per-token `Vec::insert`, which shifted the tail on every token and made
176    /// indexing O(K × N).
177    fn insert_tokens(
178        &mut self,
179        interner: &mut crate::intern::StringInterner,
180        entity_idx: u32,
181        texts: &[StrId],
182    ) {
183        let mut additions: Vec<(StrId, u32)> = Vec::new();
184        for &text in texts {
185            let s = interner.lookup(text);
186            if s.is_empty() {
187                continue;
188            }
189            self.lower_buf.clear();
190            self.lower_buf.extend(s.bytes().map(|b| b.to_ascii_lowercase()));
191            let lowered = unsafe { std::str::from_utf8_unchecked(&self.lower_buf) };
192            let tokens: Vec<&str> =
193                lowered.split_whitespace().filter(|t| !t.is_empty()).collect();
194            for token in tokens {
195                additions.push((interner.intern(token), entity_idx));
196            }
197        }
198        if additions.is_empty() {
199            return;
200        }
201        additions.sort_unstable();
202        additions.dedup();
203        self.merge_entries(&additions);
204    }
205
206    /// Merge a pre-sorted, deduped slice of new entries into `entries`
207    /// (also sorted and deduped) in one linear pass. Entries already present
208    /// are skipped, preserving the no-duplicate invariant.
209    fn merge_entries(&mut self, additions: &[(StrId, u32)]) {
210        let old = std::mem::take(&mut self.entries);
211        let mut merged = Vec::with_capacity(old.len() + additions.len());
212        let (mut i, mut j) = (0, 0);
213        while i < old.len() && j < additions.len() {
214            match old[i].cmp(&additions[j]) {
215                std::cmp::Ordering::Less => {
216                    merged.push(old[i]);
217                    i += 1;
218                }
219                std::cmp::Ordering::Greater => {
220                    merged.push(additions[j]);
221                    j += 1;
222                }
223                std::cmp::Ordering::Equal => {
224                    merged.push(old[i]);
225                    i += 1;
226                    j += 1;
227                }
228            }
229        }
230        merged.extend_from_slice(&old[i..]);
231        merged.extend_from_slice(&additions[j..]);
232        self.entries = merged;
233    }
234}
235
236impl Default for SearchIndex {
237    fn default() -> Self {
238        Self::new()
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intern::StringInterner;

    #[test]
    fn test_index_and_search() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();

        let alice_name = interner.intern("Alice");
        let alice_type = interner.intern("person");
        let alice_obs = interner.intern("likes coffee");

        index.index_entity(&mut interner, 0, alice_name, alice_type, &[alice_obs]);

        let bob_name = interner.intern("Bob");
        let bob_type = interner.intern("person");
        let bob_obs = interner.intern("drinks tea");

        index.index_entity(&mut interner, 1, bob_name, bob_type, &[bob_obs]);

        let results = index.search("coffee", &interner);
        assert_eq!(results, vec![0]);

        let results = index.search("person", &interner);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_remove_entity() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();

        let name = interner.intern("Test");
        let typ = interner.intern("type");

        index.index_entity(&mut interner, 0, name, typ, &[]);
        assert!(!index.is_empty());

        index.remove_entity(0);
        assert!(index.entries.iter().all(|&(_, idx)| idx != 0));
        // Entity 0 was the only one indexed.
        assert!(index.is_empty());
    }

    #[test]
    fn test_remove_entity_keeps_other_entities() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let alice = interner.intern("Alice");
        let bob = interner.intern("Bob");
        let person = interner.intern("person");
        index.index_entity(&mut interner, 0, alice, person, &[]);
        index.index_entity(&mut interner, 1, bob, person, &[]);

        index.remove_entity(0);
        assert!(index.search("alice", &interner).is_empty());
        assert_eq!(index.search("bob", &interner), vec![1]);
        // A token shared by both entities must still resolve to the survivor.
        assert_eq!(index.search("person", &interner), vec![1]);
    }

    #[test]
    fn test_search_empty_query() {
        let interner = StringInterner::new();
        let index = SearchIndex::new();
        assert!(index.search("", &interner).is_empty());
    }

    #[test]
    fn test_search_no_match() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        assert!(index.search("zzzzzz", &interner).is_empty());
    }

    #[test]
    fn test_search_case_insensitive() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        let results = index.search("ALICE", &interner);
        assert_eq!(results, vec![0]);
    }

    #[test]
    fn test_search_partial_prefix() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        let results = index.search("Ali", &interner);
        assert_eq!(results, vec![0]);
    }

    #[test]
    fn test_multi_token_search() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let obs = interner.intern("likes drinking coffee");
        let alice = interner.intern("Alice");
        let person = interner.intern("person");
        index.index_entity(&mut interner, 0, alice, person, &[obs]);
        assert_eq!(index.search("likes", &interner), vec![0]);
        assert_eq!(index.search("drinking", &interner), vec![0]);
        assert_eq!(index.search("coffee", &interner), vec![0]);
    }

    #[test]
    fn test_remove_then_reindex() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);

        assert_eq!(index.search("Alice", &interner).len(), 1);
        index.remove_entity(0);
        assert!(index.search("Alice", &interner).is_empty());

        index.index_entity(&mut interner, 0, name, typ, &[]);
        assert_eq!(index.search("Alice", &interner).len(), 1);
    }

    #[test]
    fn test_index_additional_merges_without_duplicates() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        let before = index.len();

        // Re-adding text that is already indexed must not create duplicates.
        index.index_additional(&mut interner, 0, &[name]);
        assert_eq!(index.len(), before);

        // New text adds exactly its new tokens ("likes", "tea").
        let obs = interner.intern("likes tea");
        index.index_additional(&mut interner, 0, &[obs]);
        assert_eq!(index.len(), before + 2);
        assert_eq!(index.search("tea", &interner), vec![0]);

        // Invariant: strictly ascending, hence sorted and duplicate-free.
        assert!(index.entries.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn test_query_longer_than_token() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let person = interner.intern("person");
        index.index_entity(&mut interner, 0, name, person, &[]);
        assert!(index.search("AliceInWonderland", &interner).is_empty());
    }

    #[test]
    fn test_empty_index() {
        let interner = StringInterner::new();
        let index = SearchIndex::new();
        assert!(index.search("anything", &interner).is_empty());
    }

    #[test]
    fn test_search_ranked_orders_by_score() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        // Entity 0: one distinct token ("coffee") starts with "cof" → score 1.
        let n0 = interner.intern("Alice");
        let t0 = interner.intern("person");
        let o0 = interner.intern("likes coffee");
        index.index_entity(&mut interner, 0, n0, t0, &[o0]);
        // Entity 1: two distinct tokens ("coffeehouse", "coffee") start with
        // "cof" → score 2. It has the larger idx, so it can only be ranked
        // first if ordering is really by score and not by the idx tie-break.
        let n1 = interner.intern("coffeehouse");
        let t1 = interner.intern("place");
        let o1 = interner.intern("serves coffee");
        index.index_entity(&mut interner, 1, n1, t1, &[o1]);

        let ranked = index.search_ranked("cof", &interner);
        assert_eq!(ranked, vec![(1, 2), (0, 1)]);
    }

    #[test]
    fn test_search_ranked_counts_distinct_tokens() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        // "coffee" occurs in both the name and the observation, but entries
        // are deduplicated per (token, entity), so it contributes 1.
        let name = interner.intern("coffee");
        let typ = interner.intern("drink");
        let obs = interner.intern("coffee beans");
        index.index_entity(&mut interner, 0, name, typ, &[obs]);
        assert_eq!(index.search_ranked("coffee", &interner), vec![(0, 1)]);
    }

    #[test]
    fn test_search_ranked_ties_break_by_idx() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let a = interner.intern("tea");
        let b = interner.intern("teapot");
        let typ = interner.intern("thing");
        index.index_entity(&mut interner, 1, a, typ, &[]);
        index.index_entity(&mut interner, 0, b, typ, &[]);
        assert_eq!(index.search_ranked("tea", &interner), vec![(0, 1), (1, 1)]);
    }

    #[test]
    fn test_search_ranked_empty_query() {
        let interner = StringInterner::new();
        let index = SearchIndex::new();
        assert!(index.search_ranked("", &interner).is_empty());
    }

    #[test]
    fn test_search_is_prefix_not_substring() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("coffee");
        let typ = interner.intern("drink");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        // A prefix of a token matches, as does the whole token.
        assert_eq!(index.search("cof", &interner), vec![0]);
        assert_eq!(index.search("coffee", &interner), vec![0]);
        // Interior substrings do NOT match: search is prefix-only.
        assert!(index.search("ffee", &interner).is_empty());
        assert!(index.search("offe", &interner).is_empty());
    }

    #[test]
    fn test_search_ranked_is_prefix_not_substring() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("coffee");
        let typ = interner.intern("drink");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        assert!(!index.search_ranked("cof", &interner).is_empty());
        assert!(index.search_ranked("ffee", &interner).is_empty());
    }

    #[test]
    fn test_clear_index() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let person = interner.intern("person");
        index.index_entity(&mut interner, 0, name, person, &[]);
        assert!(!index.is_empty());
        index.clear();
        assert!(index.is_empty());
        assert!(index.search("Alice", &interner).is_empty());
    }
}