cursed_collections/
symbol_table.rs1use ::alloc::{alloc, string::String, vec};
2use core::borrow::Borrow;
3use core::hash::{Hash, Hasher};
4use core::{cell, fmt, hash, marker, mem, ptr, slice, str};
5use hashbrown::HashSet;
6
/// Strings at least this many bytes long keep their own `String` buffer instead of
/// being copied into a shared segment.
const LARGE_SYMBOL_THRESHOLD: usize = 512;
/// Size in bytes of each shared segment that small symbols are packed into.
const SEGMENT_CAPACITY: usize = 4096;

// Every small symbol is shorter than the threshold, so keeping the threshold below
// the segment size guarantees that a small symbol fits in a freshly started segment.
#[allow(clippy::assertions_on_constants)]
const _: () = {
    let threshold_fits = LARGE_SYMBOL_THRESHOLD < SEGMENT_CAPACITY;
    assert!(threshold_fits, "a small symbol must always fit in a fresh segment");
};
15
/// Hash-set key that compares and hashes by the text it points at, not by address.
///
/// Keys stored in the intern lookup point into table-owned storage. Probe keys built
/// for a query point at the caller's string and live only for that query.
#[repr(transparent)]
struct SymbolKey(*const str);

impl PartialEq for SymbolKey {
    fn eq(&self, other: &Self) -> bool {
        // SAFETY: every `SymbolKey` wraps a pointer to a `str` that is alive while the
        // key is used (table storage, or the caller's string during a probe).
        let (lhs, rhs): (&str, &str) = unsafe { (&*self.0, &*other.0) };
        lhs == rhs
    }
}

impl Eq for SymbolKey {}

impl hash::Hash for SymbolKey {
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        // SAFETY: as in `eq`, the pointee is a live `str`.
        let text: &str = unsafe { &*self.0 };
        text.hash(state)
    }
}
32
/// A cheap, copyable handle to a string stored in a `SymbolTable`.
///
/// The `'table` lifetime ties the handle to the table that owns the bytes, so a
/// symbol cannot outlive its storage. Copying a symbol copies only the pointer.
#[repr(transparent)]
#[derive(Copy, Clone)]
pub struct Symbol<'table> {
    ptr: *const str,
    _p: marker::PhantomData<&'table str>,
}

impl<'table> Symbol<'table> {
    /// Wraps a pointer to stored text.
    ///
    /// The pointer must reference valid UTF-8 that stays alive and unmodified for
    /// `'table`. The constructor is private so only this module can uphold that.
    fn new(ptr: *const str) -> Self {
        let _p = marker::PhantomData;
        Symbol { ptr, _p }
    }

    /// Returns the symbol's text, borrowed for as long as the owning table lives.
    pub fn as_str(self) -> &'table str {
        let raw = self.ptr;
        // SAFETY: symbols are built only from storage the table keeps alive and never
        // rewrites until it is dropped, and `'table` forbids that drop while `self` exists.
        unsafe { &*raw }
    }
}
56
57impl<'table> PartialEq for Symbol<'table> {
58 fn eq(&self, other: &Self) -> bool {
59 ptr::eq(self.ptr, other.ptr)
60 }
61}
62
63impl<'table> Eq for Symbol<'table> {}
64
65impl<'table> Hash for Symbol<'table> {
66 fn hash<H: Hasher>(&self, state: &mut H) {
67 self.as_str().hash(state)
68 }
69}
70
71impl<'table> AsRef<str> for Symbol<'table> {
72 fn as_ref(&self) -> &str {
73 self.as_str()
74 }
75}
76
77impl<'table> fmt::Debug for Symbol<'table> {
78 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
79 write!(f, "{:?}@{:p}", self.as_str(), self.ptr)
80 }
81}
82
83impl<'table> fmt::Display for Symbol<'table> {
84 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85 f.write_str(self.as_str())
86 }
87}
88
89impl<'table> Borrow<str> for Symbol<'table> {
90 fn borrow(&self) -> &str {
91 self.as_ref()
92 }
93}
94
/// Allocation layout of one small-symbol segment: `SEGMENT_CAPACITY` bytes, alignment 1.
const BUFFER_LAYOUT: alloc::Layout = alloc::Layout::new::<[u8; SEGMENT_CAPACITY]>();

/// An arena that stores strings and hands out cheap `Symbol` handles to them.
///
/// Strings shorter than `LARGE_SYMBOL_THRESHOLD` bytes are copied into fixed-size
/// segments. Longer ones keep their own `String` buffer. No storage is freed before
/// the table itself is dropped, which is what allows symbols to borrow from `&self`.
pub struct SymbolTable {
    // Text-keyed set of the symbols created by `intern` (`gensym` results are not added).
    lookup: cell::UnsafeCell<HashSet<SymbolKey>>,
    // Segments that were retired from `tail` when they ran out of room; freed on drop.
    small_symbols: cell::UnsafeCell<vec::Vec<*const u8>>,
    // (pointer, length, capacity) of each large symbol's `String` buffer, rebuilt and
    // dropped when the table is dropped.
    large_symbols: cell::UnsafeCell<vec::Vec<(*const u8, usize, usize)>>,
    // Segment currently being filled with small symbols.
    tail: cell::Cell<*mut u8>,
    // Number of bytes of `tail` already handed out.
    tail_offset: cell::Cell<usize>,
}
106
107impl SymbolTable {
108 pub fn new() -> Self {
112 unsafe {
113 Self {
114 lookup: cell::UnsafeCell::new(HashSet::new()),
115 small_symbols: cell::UnsafeCell::new(vec![]),
116 large_symbols: cell::UnsafeCell::new(vec![]),
117 tail: cell::Cell::new(alloc::alloc(BUFFER_LAYOUT)),
118 tail_offset: cell::Cell::new(0),
119 }
120 }
121 }
122
123 pub fn intern(&self, text: impl Into<String> + AsRef<str>) -> Symbol {
133 unsafe {
134 let lookup = &mut *self.lookup.get();
135 if let Some(&SymbolKey(ptr)) = lookup.get(&SymbolKey(text.as_ref())) {
136 return Symbol::new(ptr);
137 }
138
139 let symbol @ Symbol { ptr, .. } = self.gensym(text);
140 lookup.insert(SymbolKey(ptr));
141 symbol
142 }
143 }
144
145 pub fn gensym(&self, text: impl Into<String> + AsRef<str>) -> Symbol {
161 unsafe {
162 let text_len = text.as_ref().len();
163 if text_len >= LARGE_SYMBOL_THRESHOLD {
164 let large_symbol = mem::ManuallyDrop::new(text.into());
165 let ptr = large_symbol.as_ptr();
166 let size = large_symbol.len();
167 (*self.large_symbols.get()).push((ptr, size, large_symbol.capacity()));
168 return Symbol::new(str::from_utf8_unchecked(slice::from_raw_parts(ptr, size)));
169 }
170
171 if text_len + self.tail_offset.get() > SEGMENT_CAPACITY {
172 self.tail_offset.set(0);
173 let prev_tail = self.tail.replace(alloc::alloc(BUFFER_LAYOUT));
174 (*self.small_symbols.get()).push(prev_tail);
175 }
176
177 let tail_offset = self.tail_offset.get();
178 let dst = self.tail.get().add(tail_offset);
179 ptr::copy_nonoverlapping(text.as_ref().as_ptr(), dst, text_len);
180 self.tail_offset.replace(tail_offset + text_len);
181 Symbol::new(str::from_utf8_unchecked(slice::from_raw_parts(
182 dst, text_len,
183 )))
184 }
185 }
186}
187
188impl Drop for SymbolTable {
189 fn drop(&mut self) {
190 unsafe {
191 alloc::dealloc(self.tail.get(), BUFFER_LAYOUT);
192 for segment in self.small_symbols.get_mut().drain(..) {
193 alloc::dealloc(segment as *mut _, BUFFER_LAYOUT);
194 }
195 for (ptr, size, capacity) in self.large_symbols.get_mut().drain(..) {
196 String::from_raw_parts(ptr as *mut _, size, capacity);
197 }
198 }
199 }
200}
201
202impl Default for SymbolTable {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::{Symbol, SymbolTable, LARGE_SYMBOL_THRESHOLD};
    use quickcheck_macros::quickcheck;
    use std::ptr;

    #[test]
    fn two_symbols_are_different() {
        let table = SymbolTable::new();
        assert_ne!(table.intern("laura"), table.intern("maddy"));
    }

    #[test]
    fn empty_symbol_is_different_from_other_symbols() {
        {
            let table = SymbolTable::new();
            assert_ne!(table.intern(""), table.intern("laura"));
        }
        {
            let table = SymbolTable::new();
            assert_ne!(table.intern("laura"), table.intern(""));
        }
    }

    #[test]
    fn interning_a_single_null_byte_works() {
        let table = SymbolTable::new();
        assert_eq!(table.intern("\0"), table.intern("\0"));
    }

    #[test]
    fn interning_a_large_string() {
        let text = "a".repeat(2 * LARGE_SYMBOL_THRESHOLD + 7);
        let table = SymbolTable::new();
        assert_eq!(table.intern(&text), table.intern(text));
    }

    // Strings just below and exactly at the threshold take different storage paths;
    // both must round-trip their text and intern canonically.
    #[test]
    fn interning_at_the_large_symbol_boundary() {
        let table = SymbolTable::new();
        let short = "b".repeat(LARGE_SYMBOL_THRESHOLD - 1);
        let long = "b".repeat(LARGE_SYMBOL_THRESHOLD);
        let short_symbol = table.intern(short.as_str());
        let long_symbol = table.intern(long.as_str());
        assert_eq!(short_symbol.as_str(), short);
        assert_eq!(long_symbol.as_str(), long);
        assert_ne!(short_symbol, long_symbol);
        assert_eq!(short_symbol, table.intern(short));
        assert_eq!(long_symbol, table.intern(long));
    }

    #[test]
    fn gensym_is_distinct_from_interned_symbols() {
        let table = SymbolTable::new();
        let interned = table.intern("laura");
        let generated = table.gensym("laura");
        assert_ne!(interned, generated);
        assert_eq!(generated.as_str(), "laura");
        assert_eq!(interned, table.intern("laura"));
    }

    #[test]
    fn interning_can_refer_to_previous_segment() {
        let table = SymbolTable::new();
        let symbol = table.intern("laura");
        // 25 strings of 234 bytes overflow a single segment.
        for c in 'a'..'z' {
            table.intern(c.to_string().repeat(234));
        }
        assert_eq!(symbol, table.intern("laura"));
    }

    #[quickcheck]
    #[cfg_attr(miri, ignore)]
    fn interning_twice_returns_same_symbol(texts: Vec<String>) -> bool {
        let table = SymbolTable::new();
        let symbols = texts
            .iter()
            .map(|text| table.intern(text))
            .collect::<Vec<_>>();
        symbols
            .into_iter()
            .zip(texts)
            .all(|(Symbol { ptr: expected, .. }, text)| {
                let Symbol { ptr: actual, .. } = table.intern(text);
                ptr::eq(expected, actual)
            })
    }
}
272}