Skip to main content

mcp_memory/
search.rs

1use crate::intern::StrId;
2
3/// Inverted word index for fast entity search.
4///
5/// For each entity, we tokenize its name, type, and observations,
6/// store each token → set of matching entity indices.
7///
8/// Uses a flat `Vec<(StrId, u32)>` sorted by (token, entity_idx)
9/// for cache-friendly lookups via binary search.
10pub struct SearchIndex {
11    // Sorted by (token, entity_idx), no duplicates.
12    entries: Vec<(StrId, u32)>,
13    lower_buf: Vec<u8>,
14}
15
16impl SearchIndex {
17    pub fn new() -> Self {
18        Self {
19            entries: Vec::new(),
20            lower_buf: Vec::with_capacity(256),
21        }
22    }
23
24    pub fn clear(&mut self) {
25        self.entries.clear();
26    }
27
28    pub const fn len(&self) -> usize {
29        self.entries.len()
30    }
31
32    pub const fn is_empty(&self) -> bool {
33        self.entries.is_empty()
34    }
35
36    /// Index a single entity by its name, type, and observations.
37    /// All strings must already be interned.
38    /// `entity_idx` is the position in the entity storage vec.
39    pub fn index_entity(
40        &mut self,
41        interner: &mut crate::intern::StringInterner,
42        entity_idx: u32,
43        name: StrId,
44        entity_type: StrId,
45        observations: &[StrId],
46    ) {
47        let mut texts = Vec::with_capacity(2 + observations.len());
48        texts.push(name);
49        texts.push(entity_type);
50        texts.extend_from_slice(observations);
51        self.insert_tokens(interner, entity_idx, &texts);
52    }
53
54    /// Incrementally index additional strings (e.g. newly added observations)
55    /// for an entity that is *already* indexed, without removing and rebuilding
56    /// its existing entries (P3). Token entries that already exist are deduped
57    /// during the merge, so calling this with text that overlaps existing
58    /// tokens is safe.
59    pub fn index_additional(
60        &mut self,
61        interner: &mut crate::intern::StringInterner,
62        entity_idx: u32,
63        texts: &[StrId],
64    ) {
65        self.insert_tokens(interner, entity_idx, texts);
66    }
67
68    /// Remove all entries for a given entity (before re-indexing).
69    pub fn remove_entity(&mut self, entity_idx: u32) {
70        self.entries.retain(|&(_, idx)| idx != entity_idx);
71    }
72
73    /// Search for entities matching the query.
74    /// Uses binary search for exact token matches (O(log n + matches))
75    /// and falls back to prefix matching for partial queries.
76    pub fn search(&self, query: &str, interner: &crate::intern::StringInterner) -> Vec<u32> {
77        if query.is_empty() || self.entries.is_empty() {
78            return Vec::new();
79        }
80
81        let lower_query: String = query.to_ascii_lowercase();
82
83        // Fast path: exact token match via binary search
84        let mut matched = if let Some(token_id) = interner.get_optional(&lower_query) {
85            let range_begin = self.entries.binary_search_by(|(t, _)| t.cmp(&token_id));
86            let range_end = self.entries.binary_search_by(|(t, _)| {
87                if *t <= token_id { std::cmp::Ordering::Less } else { std::cmp::Ordering::Greater }
88            });
89            if let (Ok(begin), Err(end)) = (range_begin, range_end) {
90                self.entries[begin..end].iter().map(|&(_, idx)| idx).collect()
91            } else {
92                Vec::new()
93            }
94        } else {
95            Vec::new()
96        };
97
98        // Prefix match scan for tokens starting with the query
99        for &(token_id, entity_idx) in &self.entries {
100            if matched.last().is_none_or(|&last| last != entity_idx) {
101                let token = interner.lookup(token_id);
102                if token.len() >= lower_query.len()
103                    && token.as_bytes().starts_with(lower_query.as_bytes())
104                {
105                    matched.push(entity_idx);
106                }
107            }
108        }
109
110        matched.sort_unstable();
111        matched.dedup();
112        matched
113    }
114
115    /// Like [`search`], but returns `(entity_idx, score)` pairs sorted by
116    /// descending score (then ascending idx for stability). `score` is the
117    /// number of indexed-token hits the entity accumulated for the query —
118    /// a cheap relevance proxy so callers can surface the best matches first.
119    ///
120    /// The scan is a single linear pass over the flat `entries` vec (no
121    /// per-entity allocation until the final compaction), keeping it
122    /// cache-friendly. A small `Vec<(idx, score)>` is gathered then sorted.
123    pub fn search_ranked(&self, query: &str, interner: &crate::intern::StringInterner) -> Vec<(u32, u32)> {
124        if query.is_empty() || self.entries.is_empty() {
125            return Vec::new();
126        }
127
128        let lower_query: String = query.to_ascii_lowercase();
129        let qbytes = lower_query.as_bytes();
130        let qlen = qbytes.len();
131
132        // Exact-token id (if the query is itself an interned token) lets us
133        // score exact hits without a string compare.
134        let exact_id = interner.get_optional(&lower_query);
135
136        // (idx, score) gathered in one pass, idx-major so equal idxs are adjacent.
137        let mut hits: Vec<(u32, u32)> = Vec::new();
138        for &(token_id, entity_idx) in &self.entries {
139            let matches = if Some(token_id) == exact_id {
140                true
141            } else {
142                let token = interner.lookup(token_id);
143                token.len() >= qlen && token.as_bytes().starts_with(qbytes)
144            };
145            if matches {
146                match hits.last_mut() {
147                    Some(last) if last.0 == entity_idx => last.1 += 1,
148                    _ => hits.push((entity_idx, 1)),
149                }
150            }
151        }
152
153        // entries are sorted by (token, idx), so a single entity_idx may appear
154        // in non-adjacent groups (once per matching token). Merge by idx, then
155        // rank by score desc.
156        hits.sort_unstable_by_key(|&(idx, _)| idx);
157        let mut merged: Vec<(u32, u32)> = Vec::with_capacity(hits.len());
158        for (idx, score) in hits {
159            match merged.last_mut() {
160                Some(last) if last.0 == idx => last.1 += score,
161                _ => merged.push((idx, score)),
162            }
163        }
164        merged.sort_unstable_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
165        merged
166    }
167
168    /// Tokenize every string in `texts`, collect the resulting
169    /// `(token, entity_idx)` entries, and merge them into the sorted `entries`
170    /// vec in a single O(N + K) pass (P2). This replaces the previous
171    /// per-token `Vec::insert`, which shifted the tail on every token and made
172    /// indexing O(K × N).
173    fn insert_tokens(
174        &mut self,
175        interner: &mut crate::intern::StringInterner,
176        entity_idx: u32,
177        texts: &[StrId],
178    ) {
179        let mut additions: Vec<(StrId, u32)> = Vec::new();
180        for &text in texts {
181            let s = interner.lookup(text);
182            if s.is_empty() {
183                continue;
184            }
185            self.lower_buf.clear();
186            self.lower_buf.extend(s.bytes().map(|b| b.to_ascii_lowercase()));
187            let lowered = unsafe { std::str::from_utf8_unchecked(&self.lower_buf) };
188            let tokens: Vec<&str> =
189                lowered.split_whitespace().filter(|t| !t.is_empty()).collect();
190            for token in tokens {
191                additions.push((interner.intern(token), entity_idx));
192            }
193        }
194        if additions.is_empty() {
195            return;
196        }
197        additions.sort_unstable();
198        additions.dedup();
199        self.merge_entries(&additions);
200    }
201
202    /// Merge a pre-sorted, deduped slice of new entries into `entries`
203    /// (also sorted and deduped) in one linear pass. Entries already present
204    /// are skipped, preserving the no-duplicate invariant.
205    fn merge_entries(&mut self, additions: &[(StrId, u32)]) {
206        let old = std::mem::take(&mut self.entries);
207        let mut merged = Vec::with_capacity(old.len() + additions.len());
208        let (mut i, mut j) = (0, 0);
209        while i < old.len() && j < additions.len() {
210            match old[i].cmp(&additions[j]) {
211                std::cmp::Ordering::Less => {
212                    merged.push(old[i]);
213                    i += 1;
214                }
215                std::cmp::Ordering::Greater => {
216                    merged.push(additions[j]);
217                    j += 1;
218                }
219                std::cmp::Ordering::Equal => {
220                    merged.push(old[i]);
221                    i += 1;
222                    j += 1;
223                }
224            }
225        }
226        merged.extend_from_slice(&old[i..]);
227        merged.extend_from_slice(&additions[j..]);
228        self.entries = merged;
229    }
230}
231
232impl Default for SearchIndex {
233    fn default() -> Self {
234        Self::new()
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intern::StringInterner;

    #[test]
    fn test_index_and_search() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();

        let alice_name = interner.intern("Alice");
        let alice_type = interner.intern("person");
        let alice_obs = interner.intern("likes coffee");

        index.index_entity(&mut interner, 0, alice_name, alice_type, &[alice_obs]);

        let bob_name = interner.intern("Bob");
        let bob_type = interner.intern("person");
        let bob_obs = interner.intern("drinks tea");

        index.index_entity(&mut interner, 1, bob_name, bob_type, &[bob_obs]);

        let results = index.search("coffee", &interner);
        assert_eq!(results, vec![0]);

        let results = index.search("person", &interner);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_remove_entity() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();

        let name = interner.intern("Test");
        let typ = interner.intern("type");

        index.index_entity(&mut interner, 0, name, typ, &[]);
        assert!(!index.is_empty());

        index.remove_entity(0);
        assert!(index.entries.iter().all(|&(_, idx)| idx != 0));
    }

    #[test]
    fn test_search_empty_query() {
        let interner = StringInterner::new();
        let index = SearchIndex::new();
        assert!(index.search("", &interner).is_empty());
    }

    #[test]
    fn test_search_no_match() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        assert!(index.search("zzzzzz", &interner).is_empty());
    }

    #[test]
    fn test_search_case_insensitive() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        let results = index.search("ALICE", &interner);
        assert_eq!(results, vec![0]);
    }

    #[test]
    fn test_search_partial_substring() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);
        let results = index.search("Ali", &interner);
        assert_eq!(results, vec![0]);
    }

    #[test]
    fn test_multi_token_search() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let obs = interner.intern("likes drinking coffee");
        let alice = interner.intern("Alice");
        let person = interner.intern("person");
        index.index_entity(&mut interner, 0, alice, person, &[obs]);
        assert_eq!(index.search("likes", &interner), vec![0]);
        assert_eq!(index.search("drinking", &interner), vec![0]);
        assert_eq!(index.search("coffee", &interner), vec![0]);
    }

    #[test]
    fn test_remove_then_reindex() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let typ = interner.intern("person");
        index.index_entity(&mut interner, 0, name, typ, &[]);

        assert_eq!(index.search("Alice", &interner).len(), 1);
        index.remove_entity(0);
        assert!(index.search("Alice", &interner).is_empty());

        index.index_entity(&mut interner, 0, name, typ, &[]);
        assert_eq!(index.search("Alice", &interner).len(), 1);
    }

    #[test]
    fn test_query_longer_than_token() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let person = interner.intern("person");
        index.index_entity(&mut interner, 0, name, person, &[]);
        assert!(index.search("AliceInWonderland", &interner).is_empty());
    }

    #[test]
    fn test_empty_index() {
        let interner = StringInterner::new();
        let index = SearchIndex::new();
        assert!(index.search("anything", &interner).is_empty());
    }

    /// Many entities share one token. This exercises long runs of equal
    /// tokens in the sorted entry list.
    #[test]
    fn test_search_shared_token_across_many_entities() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let person = interner.intern("person");
        for idx in 0..5 {
            let label = format!("name{idx}");
            let name = interner.intern(label.as_str());
            index.index_entity(&mut interner, idx, name, person, &[]);
        }
        assert_eq!(index.search("person", &interner), vec![0, 1, 2, 3, 4]);
        assert_eq!(index.search("name3", &interner), vec![3]);
        let ranked = index.search_ranked("person", &interner);
        assert_eq!(ranked, vec![(0, 1), (1, 1), (2, 1), (3, 1), (4, 1)]);
    }

    #[test]
    fn test_search_ranked_orders_by_score() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        // Tokens are deduped per entity, so the score counts *distinct*
        // indexed tokens that start with the query, not raw occurrences.
        // Entity 0: only the token "coffee" matches → score 1.
        let n0 = interner.intern("Bob");
        let t0 = interner.intern("person");
        let o0 = interner.intern("likes coffee");
        index.index_entity(&mut interner, 0, n0, t0, &[o0]);
        // Entity 1: "coffee" (name + observation, deduped) and "coffeehouse"
        // both match → score 2. It has the larger idx, so seeing it first
        // proves ranking is by score rather than by the idx tie-break.
        let n1 = interner.intern("coffee");
        let t1 = interner.intern("drink");
        let o1 = interner.intern("coffee from the coffeehouse");
        index.index_entity(&mut interner, 1, n1, t1, &[o1]);

        let ranked = index.search_ranked("coffee", &interner);
        assert_eq!(ranked, vec![(1, 2), (0, 1)]);
    }

    #[test]
    fn test_search_ranked_empty_query() {
        let interner = StringInterner::new();
        let index = SearchIndex::new();
        assert!(index.search_ranked("", &interner).is_empty());
    }

    #[test]
    fn test_search_ranked_no_match() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let person = interner.intern("person");
        index.index_entity(&mut interner, 0, name, person, &[]);
        assert!(index.search_ranked("zzz", &interner).is_empty());
    }

    #[test]
    fn test_index_additional_dedups_existing_tokens() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let person = interner.intern("person");
        index.index_entity(&mut interner, 0, name, person, &[]);
        let before = index.len();
        // "alice" is already indexed for entity 0; only "likes" and "tea"
        // should add entries.
        let extra = interner.intern("alice likes tea");
        index.index_additional(&mut interner, 0, &[extra]);
        assert_eq!(index.len(), before + 2);
        assert_eq!(index.search("tea", &interner), vec![0]);
    }

    #[test]
    fn test_clear_index() {
        let mut interner = StringInterner::new();
        let mut index = SearchIndex::new();
        let name = interner.intern("Alice");
        let person = interner.intern("person");
        index.index_entity(&mut interner, 0, name, person, &[]);
        assert!(!index.is_empty());
        index.clear();
        assert!(index.is_empty());
        assert!(index.search("Alice", &interner).is_empty());
    }
}