Skip to main content

mcp_memory/
search.rs

1use crate::intern::StrId;
2
3/// Inverted word index for fast entity search.
4///
5/// For each entity, we tokenize its name, type, and observations,
6/// store each token → set of matching entity indices.
7///
8/// Uses a flat `Vec<(StrId, u32)>` sorted by (token, entity_idx)
9/// for cache-friendly lookups via binary search.
10#[derive(Clone)]
11pub struct SearchIndex {
12    // Sorted by (token, entity_idx), no duplicates.
13    entries: Vec<(StrId, u32)>,
14    lower_buf: Vec<u8>,
15}
16
17impl SearchIndex {
18    pub fn new() -> Self {
19        Self {
20            entries: Vec::new(),
21            lower_buf: Vec::with_capacity(256),
22        }
23    }
24
25    pub fn clear(&mut self) {
26        self.entries.clear();
27    }
28
29    pub const fn len(&self) -> usize {
30        self.entries.len()
31    }
32
33    pub const fn is_empty(&self) -> bool {
34        self.entries.is_empty()
35    }
36
37    /// Index a single entity by its name, type, and observations.
38    /// All strings must already be interned.
39    /// `entity_idx` is the position in the entity storage vec.
40    pub fn index_entity(
41        &mut self,
42        interner: &mut crate::intern::StringInterner,
43        entity_idx: u32,
44        name: StrId,
45        entity_type: StrId,
46        observations: &[StrId],
47    ) {
48        let mut texts = Vec::with_capacity(2 + observations.len());
49        texts.push(name);
50        texts.push(entity_type);
51        texts.extend_from_slice(observations);
52        self.insert_tokens(interner, entity_idx, &texts);
53    }
54
55    /// Incrementally index additional strings (e.g. newly added observations)
56    /// for an entity that is *already* indexed, without removing and rebuilding
57    /// its existing entries (P3). Token entries that already exist are deduped
58    /// during the merge, so calling this with text that overlaps existing
59    /// tokens is safe.
60    pub fn index_additional(
61        &mut self,
62        interner: &mut crate::intern::StringInterner,
63        entity_idx: u32,
64        texts: &[StrId],
65    ) {
66        self.insert_tokens(interner, entity_idx, texts);
67    }
68
69    /// Remove all entries for a given entity (before re-indexing).
70    pub fn remove_entity(&mut self, entity_idx: u32) {
71        self.entries.retain(|&(_, idx)| idx != entity_idx);
72    }
73
74    /// Search for entities whose name/type/observation tokens match `query`
75    /// case-insensitively by **prefix** (`"cof"` matches `"coffee"`).
76    ///
77    /// Note: this is an O(n) scan over every index entry. The binary-search
78    /// step below only narrows exact-token hits, but the subsequent prefix scan
79    /// already covers those (an exact match is also a prefix match), so the scan
80    /// dominates — do not read the binary search as making this sublinear.
81    pub fn search(&self, query: &str, interner: &crate::intern::StringInterner) -> Vec<u32> {
82        if query.is_empty() || self.entries.is_empty() {
83            return Vec::new();
84        }
85
86        let lower_query: String = query.to_ascii_lowercase();
87
88        // Fast path: exact token match via binary search
89        let mut matched = if let Some(token_id) = interner.get_optional(&lower_query) {
90            let range_begin = self.entries.binary_search_by(|(t, _)| t.cmp(&token_id));
91            let range_end = self.entries.binary_search_by(|(t, _)| {
92                if *t <= token_id { std::cmp::Ordering::Less } else { std::cmp::Ordering::Greater }
93            });
94            if let (Ok(begin), Err(end)) = (range_begin, range_end) {
95                self.entries[begin..end].iter().map(|&(_, idx)| idx).collect()
96            } else {
97                Vec::new()
98            }
99        } else {
100            Vec::new()
101        };
102
103        // Prefix match scan for tokens starting with the query
104        for &(token_id, entity_idx) in &self.entries {
105            if matched.last().is_none_or(|&last| last != entity_idx) {
106                let token = interner.lookup(token_id);
107                if token.len() >= lower_query.len()
108                    && token.as_bytes().starts_with(lower_query.as_bytes())
109                {
110                    matched.push(entity_idx);
111                }
112            }
113        }
114
115        matched.sort_unstable();
116        matched.dedup();
117        matched
118    }
119
120    /// Like [`search`], but returns `(entity_idx, score)` pairs sorted by
121    /// descending score (then ascending idx for stability). `score` is the
122    /// number of indexed-token hits the entity accumulated for the query —
123    /// a cheap relevance proxy so callers can surface the best matches first.
124    ///
125    /// The scan is a single linear pass over the flat `entries` vec (no
126    /// per-entity allocation until the final compaction), keeping it
127    /// cache-friendly. A small `Vec<(idx, score)>` is gathered then sorted.
128    pub fn search_ranked(&self, query: &str, interner: &crate::intern::StringInterner) -> Vec<(u32, u32)> {
129        if query.is_empty() || self.entries.is_empty() {
130            return Vec::new();
131        }
132
133        let lower_query: String = query.to_ascii_lowercase();
134        let qbytes = lower_query.as_bytes();
135        let qlen = qbytes.len();
136
137        // Exact-token id (if the query is itself an interned token) lets us
138        // score exact hits without a string compare.
139        let exact_id = interner.get_optional(&lower_query);
140
141        // (idx, score) gathered in one pass, idx-major so equal idxs are adjacent.
142        let mut hits: Vec<(u32, u32)> = Vec::new();
143        for &(token_id, entity_idx) in &self.entries {
144            let matches = if Some(token_id) == exact_id {
145                true
146            } else {
147                let token = interner.lookup(token_id);
148                token.len() >= qlen && token.as_bytes().starts_with(qbytes)
149            };
150            if matches {
151                match hits.last_mut() {
152                    Some(last) if last.0 == entity_idx => last.1 += 1,
153                    _ => hits.push((entity_idx, 1)),
154                }
155            }
156        }
157
158        // entries are sorted by (token, idx), so a single entity_idx may appear
159        // in non-adjacent groups (once per matching token). Merge by idx, then
160        // rank by score desc.
161        hits.sort_unstable_by_key(|&(idx, _)| idx);
162        let mut merged: Vec<(u32, u32)> = Vec::with_capacity(hits.len());
163        for (idx, score) in hits {
164            match merged.last_mut() {
165                Some(last) if last.0 == idx => last.1 += score,
166                _ => merged.push((idx, score)),
167            }
168        }
169        merged.sort_unstable_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
170        merged
171    }
172
173    /// Tokenize every string in `texts`, collect the resulting
174    /// `(token, entity_idx)` entries, and merge them into the sorted `entries`
175    /// vec in a single O(N + K) pass (P2). This replaces the previous
176    /// per-token `Vec::insert`, which shifted the tail on every token and made
177    /// indexing O(K × N).
178    fn insert_tokens(
179        &mut self,
180        interner: &mut crate::intern::StringInterner,
181        entity_idx: u32,
182        texts: &[StrId],
183    ) {
184        let mut additions: Vec<(StrId, u32)> = Vec::new();
185        for &text in texts {
186            let s = interner.lookup(text);
187            if s.is_empty() {
188                continue;
189            }
190            self.lower_buf.clear();
191            self.lower_buf.extend(s.bytes().map(|b| b.to_ascii_lowercase()));
192            let lowered = unsafe { std::str::from_utf8_unchecked(&self.lower_buf) };
193            let tokens: Vec<&str> =
194                lowered.split_whitespace().filter(|t| !t.is_empty()).collect();
195            for token in tokens {
196                additions.push((interner.intern(token), entity_idx));
197            }
198        }
199        if additions.is_empty() {
200            return;
201        }
202        additions.sort_unstable();
203        additions.dedup();
204        self.merge_entries(&additions);
205    }
206
207    /// Merge a pre-sorted, deduped slice of new entries into `entries`
208    /// (also sorted and deduped) in one linear pass. Entries already present
209    /// are skipped, preserving the no-duplicate invariant.
210    fn merge_entries(&mut self, additions: &[(StrId, u32)]) {
211        let old = std::mem::take(&mut self.entries);
212        let mut merged = Vec::with_capacity(old.len() + additions.len());
213        let (mut i, mut j) = (0, 0);
214        while i < old.len() && j < additions.len() {
215            match old[i].cmp(&additions[j]) {
216                std::cmp::Ordering::Less => {
217                    merged.push(old[i]);
218                    i += 1;
219                }
220                std::cmp::Ordering::Greater => {
221                    merged.push(additions[j]);
222                    j += 1;
223                }
224                std::cmp::Ordering::Equal => {
225                    merged.push(old[i]);
226                    i += 1;
227                    j += 1;
228                }
229            }
230        }
231        merged.extend_from_slice(&old[i..]);
232        merged.extend_from_slice(&additions[j..]);
233        self.entries = merged;
234    }
235}
236
237impl Default for SearchIndex {
238    fn default() -> Self {
239        Self::new()
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intern::StringInterner;

    #[test]
    fn test_index_and_search() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();

        let alice_name = interner.intern("Alice");
        let alice_type = interner.intern("person");
        let alice_obs = interner.intern("likes coffee");

        index.index_entity(&mut interner, 0, alice_name, alice_type, &[alice_obs]);

        let bob_name = interner.intern("Bob");
        let bob_type = interner.intern("person");
        let bob_obs = interner.intern("drinks tea");

        index.index_entity(&mut interner, 1, bob_name, bob_type, &[bob_obs]);

        let results = index.search("coffee", &interner);
        assert_eq!(results, vec![0]);

        let results = index.search("person", &interner);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_remove_entity() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();

        let name = interner.intern("Test");
        let typ = interner.intern("type");

        index.index_entity(&mut interner, 0, name, typ, &[]);
        assert!(!index.is_empty());

        index.remove_entity(0);
        assert!(index.entries.iter().all(|&(_, idx)| idx != 0));
    }

    #[test]
    fn test_search_empty_query() {
        let interner = StringInterner::new();
        let index = SearchIndex::new();
        assert!(index.search("", &interner).is_empty());
    }

    #[test]
    fn test_search_no_match() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        assert!(index.search("zzzzzz", &interner).is_empty());
    }

    #[test]
    fn test_search_case_insensitive() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        let results = index.search("ALICE", &interner);
        assert_eq!(results, vec![0]);
    }

    #[test]
    fn test_search_partial_substring() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        let results = index.search("Ali", &interner);
        assert_eq!(results, vec![0]);
    }

    #[test]
    fn test_multi_token_search() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let obs = interner.intern("likes drinking coffee");
        let alice = interner.intern("Alice");
        let person = interner.intern("person");
        index.index_entity(&mut interner, 0, alice, person, &[obs]);
        assert_eq!(index.search("likes", &interner), vec![0]);
        assert_eq!(index.search("drinking", &interner), vec![0]);
        assert_eq!(index.search("coffee", &interner), vec![0]);
    }

    #[test]
    fn test_remove_then_reindex() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);

        assert_eq!(index.search("Alice", &interner).len(), 1);
        index.remove_entity(0);
        assert!(index.search("Alice", &interner).is_empty());

        index.index_entity(&mut interner, 0, name, typ, &[]);
        assert_eq!(index.search("Alice", &interner).len(), 1);
    }

    #[test]
    fn test_query_longer_than_token() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let person = interner.intern("person");
        index.index_entity(&mut interner, 0, name, person, &[]);
        assert!(index.search("AliceInWonderland", &interner).is_empty());
    }

    #[test]
    fn test_empty_index() {
        let interner = StringInterner::new();
        let index = SearchIndex::new();
        assert!(index.search("anything", &interner).is_empty());
    }

    #[test]
    fn test_search_ranked_orders_by_score() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        // Entity 0: one distinct matching token ("coffee") → score 1.
        let n0 = interner.intern("Bob");
        let t0 = interner.intern("person");
        let o0 = interner.intern("likes coffee");
        index.index_entity(&mut interner, 0, n0, t0, &[o0]);
        // Entity 1: two distinct matching tokens ("coffee", "coffeehouse") → score 2.
        // It has the higher idx, so ranking it first cannot come from the idx
        // tie-break.
        let n1 = interner.intern("coffee");
        let t1 = interner.intern("drink");
        let o1 = interner.intern("served at the coffeehouse");
        index.index_entity(&mut interner, 1, n1, t1, &[o1]);

        assert_eq!(index.search_ranked("coffee", &interner), vec![(1, 2), (0, 1)]);
        // Plain search reports each entity once, in ascending idx order.
        assert_eq!(index.search("coffee", &interner), vec![0, 1]);
    }

    #[test]
    fn test_search_ranked_counts_distinct_tokens_once() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        // "coffee" occurs in both the name and the observation, but the
        // (token, entity) entry is stored once, so it contributes score 1.
        let n = interner.intern("coffee");
        let t = interner.intern("drink");
        let o = interner.intern("coffee beans");
        index.index_entity(&mut interner, 0, n, t, &[o]);
        assert_eq!(index.search_ranked("coffee", &interner), vec![(0, 1)]);
    }

    #[test]
    fn test_search_ranked_empty_query() {
        let interner = StringInterner::new();
        let index = SearchIndex::new();
        assert!(index.search_ranked("", &interner).is_empty());
    }

    #[test]
    fn test_search_is_prefix_not_substring() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("coffee");
        let typ = interner.intern("drink");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        // Prefixes of a token match.
        assert_eq!(index.search("cof", &interner), vec![0]);
        assert_eq!(index.search("coffee", &interner), vec![0]);
        // Interior substrings do NOT match: this is a prefix search.
        assert!(index.search("ffee", &interner).is_empty());
        assert!(index.search("offe", &interner).is_empty());
    }

    #[test]
    fn test_search_ranked_is_prefix_not_substring() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("coffee");
        let typ = interner.intern("drink");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        assert!(!index.search_ranked("cof", &interner).is_empty());
        assert!(index.search_ranked("ffee", &interner).is_empty());
    }

    #[test]
    fn test_index_additional() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        assert!(index.search("tea", &interner).is_empty());

        let obs = interner.intern("drinks tea");
        index.index_additional(&mut interner, 0, &[obs]);
        assert_eq!(index.search("tea", &interner), vec![0]);
        assert_eq!(index.search("alice", &interner), vec![0]);

        // Re-indexing text that is already present must not add duplicates.
        let before = index.len();
        index.index_additional(&mut interner, 0, &[obs, name]);
        assert_eq!(index.len(), before);
    }

    #[test]
    fn test_clear_index() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let person = interner.intern("person");
        index.index_entity(&mut interner, 0, name, person, &[]);
        assert!(!index.is_empty());
        index.clear();
        assert!(index.is_empty());
        assert!(index.search("Alice", &interner).is_empty());
    }
}