Skip to main content

mcp_memory/
intern.rs

1use ahash::RandomState;
2use std::fmt;
3
/// Control byte that marks an unoccupied bucket in the dedup hash table.
///
/// Occupied buckets store a 7-bit `h2` stamp (0x00-0x7F), so bit 7 is set
/// only on empty buckets and a stored stamp can never equal this value.
const EMPTY: u8 = 0xFF;
7
/// Compact handle to a string held by a `StringInterner`.
///
/// The wrapped `u32` is the insertion index of the string, except for the
/// reserved value `u32::MAX`, which stands for the empty string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct StrId(u32);

impl StrId {
    /// Reserved id that represents the empty string.
    pub const EMPTY: StrId = Self(u32::MAX);

    /// Returns `true` when this id is the reserved [`StrId::EMPTY`] value.
    #[inline]
    pub const fn is_empty(self) -> bool {
        self.0 == Self::EMPTY.0
    }

    /// Returns the raw `u32` wrapped by this id.
    #[inline]
    pub const fn as_u32(self) -> u32 {
        self.0
    }
}

impl fmt::Display for StrId {
    /// Writes `<empty>` for the reserved empty id and `StrId(n)` otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_empty() {
            return f.write_str("<empty>");
        }
        let raw = self.0;
        write!(f, "StrId({})", raw)
    }
}
35
/// Returns the 7-bit stamp stored in a bucket's control byte: the low seven
/// bits of `hash`.
///
/// Bit 7 of the result is always clear, so a stamp can never equal the
/// `EMPTY` control byte (0xFF).
#[inline(always)]
const fn h2(hash: u64) -> u8 {
    let low_seven = hash & 0x7F;
    low_seven as u8
}
42
/// Returns the bucket index at which probing starts for `hash`.
///
/// The low seven bits (used by `h2` as the stamp) are shifted out and the
/// remainder is reduced with `mask`, which callers pass as the bucket count
/// minus one.
#[inline(always)]
const fn h1(hash: u64, mask: usize) -> usize {
    let above_stamp = hash >> 7;
    (above_stamp as usize) & mask
}
48
49pub struct StringInterner {
50    arena: String,
51    offsets: Vec<u32>,
52    hashes: Vec<u64>,
53    // Dedup hash table — ctrl-byte buckets with parallel arrays.
54    ctrl: Vec<u8>,
55    table_hashes: Vec<u64>,
56    table_ids: Vec<StrId>,
57    table_mask: usize,
58    count: usize,
59    hasher: RandomState,
60}
61
62impl StringInterner {
63    pub fn new() -> Self {
64        const CAP: usize = 256;
65        Self {
66            arena: String::with_capacity(8192),
67            offsets: vec![0],
68            hashes: Vec::with_capacity(128),
69            ctrl: vec![EMPTY; CAP],
70            table_hashes: vec![0; CAP],
71            table_ids: vec![StrId::EMPTY; CAP],
72            table_mask: CAP - 1,
73            count: 0,
74            hasher: RandomState::new(),
75        }
76    }
77
78    pub fn with_capacity(string_capacity: usize, estimated_strings: usize) -> Self {
79        let cap = estimated_strings.next_power_of_two().max(64);
80        Self {
81            arena: String::with_capacity(string_capacity),
82            offsets: vec![0],
83            hashes: Vec::with_capacity(estimated_strings),
84            ctrl: vec![EMPTY; cap],
85            table_hashes: vec![0; cap],
86            table_ids: vec![StrId::EMPTY; cap],
87            table_mask: cap - 1,
88            count: 0,
89            hasher: RandomState::new(),
90        }
91    }
92
93    #[inline]
94    pub fn intern(&mut self, s: &str) -> StrId {
95        if s.is_empty() {
96            return StrId::EMPTY;
97        }
98        let hash = self.hasher.hash_one(s);
99        let stamp = h2(hash);
100        let mask = self.table_mask;
101        let mut idx = h1(hash, mask);
102
103        loop {
104            let c = &self.ctrl[idx];
105            if *c & 0x80 != 0 {
106                // Empty slot → insert new string.
107                let id = self.offsets.len() as u32 - 1;
108                self.arena.push_str(s);
109                self.offsets.push(self.arena.len() as u32);
110                self.hashes.push(hash);
111                self.ctrl[idx] = stamp;
112                self.table_hashes[idx] = hash;
113                self.table_ids[idx] = StrId(id);
114                self.count += 1;
115                if self.count * 4 > self.ctrl.len() * 3 {
116                    self.grow();
117                }
118                return StrId(id);
119            }
120            if *c == stamp && self.table_hashes[idx] == hash {
121                let existing = self.table_ids[idx].0;
122                let start = self.offsets[existing as usize] as usize;
123                let end = self.offsets[existing as usize + 1] as usize;
124                let existing_str = unsafe { self.arena.get_unchecked(start..end) };
125                if existing_str == s {
126                    return StrId(existing);
127                }
128            }
129            idx = (idx + 1) & mask;
130        }
131    }
132
133    /// Look up a string in the dedup table without inserting.
134    /// Returns `StrId` if the string already exists, `None` otherwise.
135    #[inline]
136    pub fn get_optional(&self, s: &str) -> Option<StrId> {
137        if s.is_empty() {
138            return Some(StrId::EMPTY);
139        }
140        let hash = self.hasher.hash_one(s);
141        let stamp = h2(hash);
142        let mask = self.table_mask;
143        let mut idx = h1(hash, mask);
144
145        for _ in 0..self.ctrl.len() {
146            let c = self.ctrl[idx];
147            if c & 0x80 != 0 {
148                return None;
149            }
150            if c == stamp && self.table_hashes[idx] == hash {
151                let existing = self.table_ids[idx].0;
152                let start = self.offsets[existing as usize] as usize;
153                let end = self.offsets[existing as usize + 1] as usize;
154                if unsafe { self.arena.get_unchecked(start..end) == s } {
155                    return Some(StrId(existing));
156                }
157            }
158            idx = (idx + 1) & mask;
159        }
160        None
161    }
162
163    #[inline]
164    pub fn lookup(&self, id: StrId) -> &str {
165        if id.is_empty() {
166            return "";
167        }
168        let start = self.offsets[id.0 as usize] as usize;
169        let end = self.offsets[id.0 as usize + 1] as usize;
170        unsafe { self.arena.get_unchecked(start..end) }
171    }
172
173    #[inline]
174    pub fn get_hash(&self, id: StrId) -> u64 {
175        if id.is_empty() {
176            return 0;
177        }
178        self.hashes[id.0 as usize]
179    }
180
181    pub const fn len(&self) -> usize {
182        self.offsets.len() - 1
183    }
184
185    pub const fn is_empty(&self) -> bool {
186        self.len() == 0
187    }
188
189    pub const fn total_bytes(&self) -> usize {
190        self.arena.len()
191    }
192
193    fn grow(&mut self) {
194        let new_size = self.ctrl.len() * 2;
195        let new_mask = new_size - 1;
196        let mut new_ctrl = vec![EMPTY; new_size];
197        let mut new_hashes = vec![0u64; new_size];
198        let mut new_ids = vec![StrId::EMPTY; new_size];
199
200        for i in 0..self.ctrl.len() {
201            if self.ctrl[i] & 0x80 == 0 {
202                let hash = self.table_hashes[i];
203                let stamp = h2(hash);
204                let mut idx = h1(hash, new_mask);
205                while new_ctrl[idx] & 0x80 == 0 {
206                    idx = (idx + 1) & new_mask;
207                }
208                new_ctrl[idx] = stamp;
209                new_hashes[idx] = hash;
210                new_ids[idx] = self.table_ids[i];
211            }
212        }
213
214        self.ctrl = new_ctrl;
215        self.table_hashes = new_hashes;
216        self.table_ids = new_ids;
217        self.table_mask = new_mask;
218    }
219}
220
221impl Default for StringInterner {
222    fn default() -> Self {
223        Self::new()
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::{StrId, StringInterner};

    /// The empty string maps to the reserved empty id.
    #[test]
    fn test_intern_empty() {
        let mut pool = StringInterner::new();
        let id = pool.intern("");
        assert!(id.is_empty());
    }

    /// Interning the same text twice yields the same id.
    #[test]
    fn test_intern_dedup() {
        let mut pool = StringInterner::new();
        let first = pool.intern("hello");
        let second = pool.intern("hello");
        assert_eq!(first, second);
    }

    /// Different texts get different ids.
    #[test]
    fn test_intern_unique() {
        let mut pool = StringInterner::new();
        let hello = pool.intern("hello");
        let world = pool.intern("world");
        assert_ne!(hello, world);
    }

    /// An id resolves back to the text it was created from.
    #[test]
    fn test_lookup() {
        let mut pool = StringInterner::new();
        let text = "hello world";
        let id = pool.intern(text);
        assert_eq!(pool.lookup(id), "hello world");
    }

    /// A thousand distinct strings all survive table growth.
    #[test]
    fn test_large_intern() {
        let mut pool = StringInterner::new();
        let ids: Vec<StrId> = (0..1000)
            .map(|i| pool.intern(&format!("string_{i}")))
            .collect();
        for (i, id) in ids.into_iter().enumerate() {
            assert_eq!(pool.lookup(id), format!("string_{i}"));
        }
        assert_eq!(pool.len(), 1000);
    }

    /// Looking up the empty id yields the empty string.
    #[test]
    fn test_lookup_empty_id() {
        let pool = StringInterner::new();
        assert_eq!(pool.lookup(StrId::EMPTY), "");
    }

    /// The empty id reports a hash of zero.
    #[test]
    fn test_get_hash_empty_id() {
        let pool = StringInterner::new();
        assert_eq!(pool.get_hash(StrId::EMPTY), 0);
    }

    /// The stored hash is stable across repeated interning of one string.
    #[test]
    fn test_get_hash_consistency() {
        let mut pool = StringInterner::new();
        let first = pool.intern("consistent");
        let hash_before = pool.get_hash(first);
        let second = pool.intern("consistent");
        assert_eq!(first, second);
        let hash_after = pool.get_hash(second);
        assert_eq!(hash_before, hash_after);
    }

    /// Arena size grows only when a new string is stored.
    #[test]
    fn test_total_bytes() {
        let mut pool = StringInterner::new();
        assert_eq!(pool.total_bytes(), 0);
        // (text to intern, arena size afterwards); the repeated "abc" is a
        // dedup hit and must add no bytes.
        let steps: [(&str, usize); 3] = [("abc", 3), ("defg", 7), ("abc", 7)];
        for &(text, expected) in &steps {
            pool.intern(text);
            assert_eq!(pool.total_bytes(), expected);
        }
    }

    /// Interning the empty string repeatedly always returns the empty id.
    #[test]
    fn test_intern_empty_via_lookup() {
        let mut pool = StringInterner::new();
        let first = pool.intern("");
        assert!(first.is_empty());
        // Interning empty again should still return EMPTY.
        let second = pool.intern("");
        assert!(second.is_empty());
    }

    /// Strings stay reachable after the table has been forced to grow.
    #[test]
    fn test_grow_triggers() {
        let mut pool = StringInterner::with_capacity(4096, 16);
        let names: Vec<String> = (0..100).map(|i| format!("grow_test_{i}")).collect();
        // Insert enough strings to force a grow.
        for name in &names {
            pool.intern(name);
        }
        assert_eq!(pool.len(), 100);
        // Verify all strings are still accessible.
        for name in &names {
            let id = pool.intern(name);
            assert_eq!(pool.lookup(id), name.as_str());
        }
    }

    /// Re-interning one string many times never creates a second entry.
    #[test]
    fn test_many_dedup_same_string() {
        let mut pool = StringInterner::new();
        let original = pool.intern("same");
        for _ in 0..1000 {
            assert_eq!(pool.intern("same"), original);
        }
        assert_eq!(pool.len(), 1);
    }

    /// A pre-sized interner starts empty and accepts the estimated load.
    #[test]
    fn test_interner_with_capacity() {
        let mut pool = StringInterner::with_capacity(1024, 50);
        assert_eq!(pool.len(), 0);
        for i in 0..50 {
            let text = format!("cap_test_{i}");
            pool.intern(&text);
        }
        assert_eq!(pool.len(), 50);
    }

    /// Strings differing only in case are distinct entries.
    #[test]
    fn test_case_sensitive_dedup() {
        let mut pool = StringInterner::new();
        let upper = pool.intern("Hello");
        let lower = pool.intern("hello");
        assert_ne!(upper, lower); // case-sensitive dedup
    }

    /// `Default` produces a working interner.
    #[test]
    fn test_default_impl() {
        let mut pool: StringInterner = Default::default();
        let id = pool.intern("default");
        assert_eq!(pool.lookup(id), "default");
    }

    /// `is_empty` flips to false after the first stored string.
    #[test]
    fn test_is_empty_method() {
        let mut pool = StringInterner::new();
        assert!(pool.is_empty());
        pool.intern("x");
        assert!(!pool.is_empty());
    }
}