cursed_collections/
symbol_table.rs

1use ::alloc::{alloc, string::String, vec};
2use core::borrow::Borrow;
3use core::hash::{Hash, Hasher};
4use core::{cell, fmt, hash, marker, mem, ptr, slice, str};
5use hashbrown::HashSet;
6
/// Length in bytes from which a symbol is considered "large": `gensym` gives such symbols their
/// own heap allocation instead of copying them into a shared segment.
const LARGE_SYMBOL_THRESHOLD: usize = 1 << 9;

/// Size in bytes of each segment that small symbols are packed into.
const SEGMENT_CAPACITY: usize = 1 << 12;

// Compile-time check tying the two constants together. Both operands are constants on purpose,
// which is exactly what the silenced lint would complain about.
#[allow(clippy::assertions_on_constants)]
const _: () = assert!(
    SEGMENT_CAPACITY > LARGE_SYMBOL_THRESHOLD,
    "a small symbol must always fit in a fresh segment",
);
15
/// Hash-set key that compares and hashes by the text behind a raw `str` pointer.
///
/// The pointer is dereferenced on every comparison and on every hash, so it must refer to a live,
/// valid `str` for as long as the key is in use.
#[repr(transparent)]
struct SymbolKey(*const str);

impl SymbolKey {
    /// Borrows the text this key points at.
    fn text(&self) -> &str {
        // SAFETY: whoever builds a `SymbolKey` must keep the pointed-to `str` alive and valid
        // while the key is compared or hashed (see the type-level documentation).
        unsafe { &*self.0 }
    }
}

impl PartialEq for SymbolKey {
    /// Two keys are equal when their texts are equal, wherever those texts are stored.
    fn eq(&self, other: &Self) -> bool {
        self.text() == other.text()
    }
}

impl Eq for SymbolKey {}

impl Hash for SymbolKey {
    /// Hashes the text exactly like `str` does, keeping `Hash` consistent with `PartialEq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.text().hash(state)
    }
}
32
33/// Like a `&str`, but with constant time equality comparison.
34///
35/// It is a distinct type from `&str` to avoid confusion where an interned string could be compared
36/// to an uninterned string and give a confusing false negative.
37#[repr(transparent)]
38#[derive(Copy, Clone)]
39pub struct Symbol<'table> {
40    ptr: *const str,
41    _p: marker::PhantomData<&'table str>,
42}
43
44impl<'table> Symbol<'table> {
45    fn new(ptr: *const str) -> Self {
46        Self {
47            ptr,
48            _p: marker::PhantomData,
49        }
50    }
51
52    pub fn as_str(self) -> &'table str {
53        unsafe { &*self.ptr }
54    }
55}
56
57impl<'table> PartialEq for Symbol<'table> {
58    fn eq(&self, other: &Self) -> bool {
59        ptr::eq(self.ptr, other.ptr)
60    }
61}
62
63impl<'table> Eq for Symbol<'table> {}
64
65impl<'table> Hash for Symbol<'table> {
66    fn hash<H: Hasher>(&self, state: &mut H) {
67        self.as_str().hash(state)
68    }
69}
70
71impl<'table> AsRef<str> for Symbol<'table> {
72    fn as_ref(&self) -> &str {
73        self.as_str()
74    }
75}
76
77impl<'table> fmt::Debug for Symbol<'table> {
78    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
79        write!(f, "{:?}@{:p}", self.as_str(), self.ptr)
80    }
81}
82
83impl<'table> fmt::Display for Symbol<'table> {
84    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85        f.write_str(self.as_str())
86    }
87}
88
89impl<'table> Borrow<str> for Symbol<'table> {
90    fn borrow(&self) -> &str {
91        self.as_ref()
92    }
93}
94
95const BUFFER_LAYOUT: alloc::Layout = alloc::Layout::new::<[u8; SEGMENT_CAPACITY]>();
96
97/// A set of strings. Unlike a regular set, strings are stored contiguously in pages to reduce
98/// memory usage.
99pub struct SymbolTable {
100    lookup: cell::UnsafeCell<HashSet<SymbolKey>>,
101    small_symbols: cell::UnsafeCell<vec::Vec<*const u8>>,
102    large_symbols: cell::UnsafeCell<vec::Vec<(*const u8, usize, usize)>>,
103    tail: cell::Cell<*mut u8>,
104    tail_offset: cell::Cell<usize>,
105}
106
107impl SymbolTable {
108    /// Create an empty table.
109    ///
110    /// Unlike many types in `alloc`, this allocates right away.
111    pub fn new() -> Self {
112        unsafe {
113            Self {
114                lookup: cell::UnsafeCell::new(HashSet::new()),
115                small_symbols: cell::UnsafeCell::new(vec![]),
116                large_symbols: cell::UnsafeCell::new(vec![]),
117                tail: cell::Cell::new(alloc::alloc(BUFFER_LAYOUT)),
118                tail_offset: cell::Cell::new(0),
119            }
120        }
121    }
122
123    /// Adds a symbol to the table if it does not exist.
124    ///
125    /// # Example
126    ///
127    /// ```
128    /// # use cursed_collections::SymbolTable;
129    /// let table = SymbolTable::new();
130    /// assert_eq!(table.intern("my symbol"), table.intern("my symbol"));
131    /// ```
132    pub fn intern(&self, text: impl Into<String> + AsRef<str>) -> Symbol {
133        unsafe {
134            let lookup = &mut *self.lookup.get();
135            if let Some(&SymbolKey(ptr)) = lookup.get(&SymbolKey(text.as_ref())) {
136                return Symbol::new(ptr);
137            }
138
139            let symbol @ Symbol { ptr, .. } = self.gensym(text);
140            lookup.insert(SymbolKey(ptr));
141            symbol
142        }
143    }
144
145    /// Adds a symbol to the table. This symbol is always considered distinct from all other symbols
146    /// even if they are textually identical.
147    ///
148    /// # Example
149    ///
150    /// ```
151    /// # use cursed_collections::SymbolTable;
152    /// let table = SymbolTable::new();
153    /// assert_ne!(table.intern("my symbol"), table.gensym("my symbol"));
154    /// ```
155    ///
156    /// # Name
157    ///
158    /// The name "`gensym`" is common within the Lisp family of languages where symbols are built in
159    /// the language itself.
160    pub fn gensym(&self, text: impl Into<String> + AsRef<str>) -> Symbol {
161        unsafe {
162            let text_len = text.as_ref().len();
163            if text_len >= LARGE_SYMBOL_THRESHOLD {
164                let large_symbol = mem::ManuallyDrop::new(text.into());
165                let ptr = large_symbol.as_ptr();
166                let size = large_symbol.len();
167                (*self.large_symbols.get()).push((ptr, size, large_symbol.capacity()));
168                return Symbol::new(str::from_utf8_unchecked(slice::from_raw_parts(ptr, size)));
169            }
170
171            if text_len + self.tail_offset.get() > SEGMENT_CAPACITY {
172                self.tail_offset.set(0);
173                let prev_tail = self.tail.replace(alloc::alloc(BUFFER_LAYOUT));
174                (*self.small_symbols.get()).push(prev_tail);
175            }
176
177            let tail_offset = self.tail_offset.get();
178            let dst = self.tail.get().add(tail_offset);
179            ptr::copy_nonoverlapping(text.as_ref().as_ptr(), dst, text_len);
180            self.tail_offset.replace(tail_offset + text_len);
181            Symbol::new(str::from_utf8_unchecked(slice::from_raw_parts(
182                dst, text_len,
183            )))
184        }
185    }
186}
187
188impl Drop for SymbolTable {
189    fn drop(&mut self) {
190        unsafe {
191            alloc::dealloc(self.tail.get(), BUFFER_LAYOUT);
192            for segment in self.small_symbols.get_mut().drain(..) {
193                alloc::dealloc(segment as *mut _, BUFFER_LAYOUT);
194            }
195            for (ptr, size, capacity) in self.large_symbols.get_mut().drain(..) {
196                String::from_raw_parts(ptr as *mut _, size, capacity);
197            }
198        }
199    }
200}
201
202impl Default for SymbolTable {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::{SymbolTable, LARGE_SYMBOL_THRESHOLD};
    use quickcheck_macros::quickcheck;

    /// Distinct texts yield distinct symbols.
    #[test]
    fn two_symbols_are_different() {
        let table = SymbolTable::new();
        let laura = table.intern("laura");
        let maddy = table.intern("maddy");
        assert_ne!(laura, maddy);
    }

    /// The empty symbol starts at the same address as its neighbour in a fresh table, so this
    /// checks that the length takes part in the comparison, in both insertion orders.
    #[test]
    fn empty_symbol_is_different_from_other_symbols() {
        {
            let table = SymbolTable::new();
            let empty = table.intern("");
            let laura = table.intern("laura");
            assert_ne!(empty, laura);
        }
        {
            let table = SymbolTable::new();
            let laura = table.intern("laura");
            let empty = table.intern("");
            assert_ne!(laura, empty);
        }
    }

    /// A NUL byte is ordinary text as far as the table is concerned.
    #[test]
    fn interning_a_single_null_byte_works() {
        let table = SymbolTable::new();
        let first = table.intern("\0");
        let second = table.intern("\0");
        assert_eq!(first, second);
    }

    /// Exercises the large-symbol path, once through a borrowed and once through an owned string.
    #[test]
    fn interning_a_large_string() {
        let text = "a".repeat(2 * LARGE_SYMBOL_THRESHOLD + 7);
        let table = SymbolTable::new();
        let borrowed = table.intern(&text);
        let owned = table.intern(text);
        assert_eq!(borrowed, owned);
    }

    /// Symbols stay valid and findable after the segment they live in has been replaced.
    #[test]
    fn interning_can_refer_to_previous_segment() {
        let table = SymbolTable::new();
        let symbol = table.intern("laura");
        ('a'..'z').for_each(|c| {
            table.intern(c.to_string().repeat(234));
        });
        assert_eq!(symbol, table.intern("laura"));
    }

    /// Interning any text a second time returns the symbol obtained the first time.
    #[quickcheck]
    #[cfg_attr(miri, ignore)]
    fn interning_twice_returns_same_symbol(texts: Vec<String>) -> bool {
        let table = SymbolTable::new();
        let first_pass: Vec<_> = texts.iter().map(|text| table.intern(text)).collect();
        first_pass
            .into_iter()
            .zip(texts)
            .all(|(expected, text)| expected == table.intern(text))
    }
}