ra_ap_intern/
symbol.rs

//! An attempt at flexible symbol interning, allowing strings to be interned and freed at runtime
//! while also supporting compile-time declaration of symbols that will never be freed.
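//!
//! A minimal usage sketch (illustrative only; `"hello"` is an arbitrary example string):
//!
//! ```ignore
//! let a = Symbol::intern("hello");
//! let b = Symbol::intern("hello");
//! assert_eq!(a, b); // equal strings intern to the same symbol
//! assert_eq!(a.as_str(), "hello");
//! ```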

use std::{
    fmt,
    hash::{BuildHasher, BuildHasherDefault, Hash},
    mem::{self, ManuallyDrop},
    ptr::NonNull,
    sync::OnceLock,
};

use dashmap::{DashMap, SharedValue};
use hashbrown::raw::RawTable;
use rustc_hash::FxHasher;
use triomphe::Arc;

pub mod symbols;

// some asserts for layout compatibility
const _: () = assert!(size_of::<Box<str>>() == size_of::<&str>());
const _: () = assert!(align_of::<Box<str>>() == align_of::<&str>());

const _: () = assert!(size_of::<Arc<Box<str>>>() == size_of::<&&str>());
const _: () = assert!(align_of::<Arc<Box<str>>>() == align_of::<&&str>());

const _: () = assert!(size_of::<*const *const str>() == size_of::<TaggedArcPtr>());
const _: () = assert!(align_of::<*const *const str>() == align_of::<TaggedArcPtr>());

const _: () = assert!(size_of::<Arc<Box<str>>>() == size_of::<TaggedArcPtr>());
const _: () = assert!(align_of::<Arc<Box<str>>>() == align_of::<TaggedArcPtr>());

/// A pointer to a pointer to a `str`. It may be backed by a `&'static &'static str` or an
/// `Arc<Box<str>>`, but its size is that of a thin pointer. The active variant is encoded as a
/// tag in the LSB of the alignment niche.
// Note: ideally this would encode a `ThinArc<str>` and `ThinRef<str>`/`ThinConstPtr<str>` instead of the double indirection.
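// Tag layout sketch (assuming the pointed-to data is at least 2-byte aligned, so the low address
// bit is always free):
//   tag bit clear -> `packed` is a plain `&'static &'static str`
//   tag bit set   -> `packed` is the pointer from `Arc::into_raw` on an `Arc<Box<str>>`, OR-ed with 1
// `pointer()` masks the tag bit off again before either variant is dereferenced.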
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
struct TaggedArcPtr {
    packed: NonNull<*const str>,
}

unsafe impl Send for TaggedArcPtr {}
unsafe impl Sync for TaggedArcPtr {}

impl TaggedArcPtr {
    const BOOL_BITS: usize = true as usize;

    const fn non_arc(r: &'static &'static str) -> Self {
        assert!(align_of::<&'static &'static str>().trailing_zeros() as usize > Self::BOOL_BITS);
        // SAFETY: The pointer is non-null as it is derived from a reference.
        // Ideally we would call out to `pack_arc` with a `false` tag, but the packing code needs
        // to read the pointer out as an integer, which is not supported in const contexts.
        // Instead we rely on the fact that the non-arc variant uses a `false` (0) tag and thus
        // does not need to touch the actual pointer value.

        let packed =
            unsafe { NonNull::new_unchecked((r as *const &str).cast::<*const str>().cast_mut()) };
        Self { packed }
    }

    fn arc(arc: Arc<Box<str>>) -> Self {
        assert!(align_of::<&'static &'static str>().trailing_zeros() as usize > Self::BOOL_BITS);
        Self {
            packed: Self::pack_arc(
                // SAFETY: `Arc::into_raw` always returns a non-null pointer
                unsafe { NonNull::new_unchecked(Arc::into_raw(arc).cast_mut().cast()) },
            ),
        }
    }

    /// Returns the backing `Arc<Box<str>>` if this pointer holds the arc variant (tag bit set),
    /// or `None` for the static variant.
    ///
    /// # Safety
    ///
    /// You may only drop the returned `Arc` if this instance is itself being dropped, as the
    /// instance logically owns that reference.
    #[inline]
    pub(crate) unsafe fn try_as_arc_owned(self) -> Option<ManuallyDrop<Arc<Box<str>>>> {
        // Unpack the tag from the alignment niche
        let tag = self.packed.as_ptr().addr() & Self::BOOL_BITS;
        if tag != 0 {
            // SAFETY: The tag is set, so the (masked) pointer points to the data of an `Arc`
            Some(ManuallyDrop::new(unsafe {
                Arc::from_raw(self.pointer().as_ptr().cast::<Box<str>>())
            }))
        } else {
            None
        }
    }

    #[inline]
    fn pack_arc(ptr: NonNull<*const str>) -> NonNull<*const str> {
        let packed_tag = true as usize;

        unsafe {
            // SAFETY: The address comes from a non-null pointer, and OR-ing it with 1 cannot
            // make it null.
            NonNull::new_unchecked(ptr.as_ptr().map_addr(|addr| addr | packed_tag))
        }
    }

    #[inline]
    pub(crate) fn pointer(self) -> NonNull<*const str> {
        // SAFETY: The resulting pointer is guaranteed to be non-null as we only clear the tag
        // bit in the alignment niche
        unsafe {
            NonNull::new_unchecked(self.packed.as_ptr().map_addr(|addr| addr & !Self::BOOL_BITS))
        }
    }

    #[inline]
    pub(crate) fn as_str(&self) -> &str {
        // SAFETY: We always point to a pointer to a str, no matter which variant is active
        unsafe { *self.pointer().as_ptr().cast::<&str>() }
    }
}

#[derive(PartialEq, Eq, Hash)]
pub struct Symbol {
    repr: TaggedArcPtr,
}

impl fmt::Debug for Symbol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.as_str().fmt(f)
    }
}

const _: () = assert!(size_of::<Symbol>() == size_of::<NonNull<()>>());
const _: () = assert!(align_of::<Symbol>() == align_of::<NonNull<()>>());

type Map = DashMap<Symbol, (), BuildHasherDefault<FxHasher>>;
static MAP: OnceLock<Map> = OnceLock::new();

impl Symbol {
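    /// Interns `s`: looks the string up in the global map (inserting it if needed) and returns
    /// a `Symbol` backed by the single canonical entry, so equal strings yield equal symbols.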
    pub fn intern(s: &str) -> Self {
        let storage = MAP.get_or_init(symbols::prefill);
        let (mut shard, hash) = Self::select_shard(storage, s);
        // Atomically,
        // - check if `s` is already in the map
        //   - if so, copy out its entry, conditionally bumping the backing Arc, and return it
        //   - if not, put it into a box and then into an Arc, insert it, bump the ref-count and return the copy
        // This needs to be atomic (locking the shard) to avoid races with other threads, which
        // could insert the same string between us looking it up and inserting it.
        let bucket = match shard.find_or_find_insert_slot(
            hash,
            |(other, _)| other.as_str() == s,
            |(x, _)| Self::hash(storage, x.as_str()),
        ) {
            Ok(bucket) => bucket,
            // SAFETY: The slot came from `find_or_find_insert_slot()`, and the table wasn't modified since then.
            Err(insert_slot) => unsafe {
                shard.insert_in_slot(
                    hash,
                    insert_slot,
                    (
                        Symbol { repr: TaggedArcPtr::arc(Arc::new(Box::<str>::from(s))) },
                        SharedValue::new(()),
                    ),
                )
            },
        };
        // SAFETY: We just retrieved/inserted this bucket.
        unsafe { bucket.as_ref().0.clone() }
    }

    pub fn integer(i: usize) -> Self {
        match i {
            0 => symbols::INTEGER_0,
            1 => symbols::INTEGER_1,
            2 => symbols::INTEGER_2,
            3 => symbols::INTEGER_3,
            4 => symbols::INTEGER_4,
            5 => symbols::INTEGER_5,
            6 => symbols::INTEGER_6,
            7 => symbols::INTEGER_7,
            8 => symbols::INTEGER_8,
            9 => symbols::INTEGER_9,
            10 => symbols::INTEGER_10,
            11 => symbols::INTEGER_11,
            12 => symbols::INTEGER_12,
            13 => symbols::INTEGER_13,
            14 => symbols::INTEGER_14,
            15 => symbols::INTEGER_15,
            i => Symbol::intern(&format!("{i}")),
        }
    }

    pub fn empty() -> Self {
        symbols::__empty
    }

    #[inline]
    pub fn as_str(&self) -> &str {
        self.repr.as_str()
    }

    #[inline]
    fn select_shard(
        storage: &'static Map,
        s: &str,
    ) -> (dashmap::RwLockWriteGuard<'static, RawTable<(Symbol, SharedValue<()>)>>, u64) {
        let hash = Self::hash(storage, s);
        let shard_idx = storage.determine_shard(hash as usize);
        let shard = &storage.shards()[shard_idx];
        (shard.write(), hash)
    }

    #[inline]
    fn hash(storage: &'static Map, s: &str) -> u64 {
        storage.hasher().hash_one(s)
    }

    #[cold]
    fn drop_slow(arc: &Arc<Box<str>>) {
        let storage = MAP.get_or_init(symbols::prefill);
        let (mut shard, hash) = Self::select_shard(storage, arc);

        match Arc::count(arc) {
            0 | 1 => unreachable!(),
            2 => (),
            _ => {
                // Another thread has interned another copy
                return;
            }
        }

        let s = &***arc;
        let (ptr, _) = shard.remove_entry(hash, |(x, _)| x.as_str() == s).unwrap();
        let ptr = ManuallyDrop::new(ptr);
        // SAFETY: We're dropping, we have ownership.
        ManuallyDrop::into_inner(unsafe { ptr.repr.try_as_arc_owned().unwrap() });
        debug_assert_eq!(Arc::count(arc), 1);

        // Shrink the backing storage if the shard is less than 50% occupied.
        if shard.len() * 2 < shard.capacity() {
            let len = shard.len();
            shard.shrink_to(len, |(x, _)| Self::hash(storage, x.as_str()));
        }
    }
}

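// Reference-count bookkeeping, as implied by `intern` and `drop_slow` above: for an Arc-backed
// symbol the global map holds one reference and every live `Symbol` holds one more, so observing
// a count of exactly 2 while dropping means `self` is the last symbol outside the map. Statically
// prefilled symbols use the non-arc variant and never reach this path.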
impl Drop for Symbol {
    #[inline]
    fn drop(&mut self) {
        // SAFETY: We're dropping, we have ownership.
        let Some(arc) = (unsafe { self.repr.try_as_arc_owned() }) else {
            return;
        };
        // When the last `Symbol` outside the map is dropped, remove the string from the global map.
        if Arc::count(&arc) == 2 {
            // Only `self` and the global map point to the object.
            Self::drop_slow(&arc);
        }
        // decrement the ref count
        ManuallyDrop::into_inner(arc);
    }
}

impl Clone for Symbol {
    fn clone(&self) -> Self {
        Self { repr: increase_arc_refcount(self.repr) }
    }
}

fn increase_arc_refcount(repr: TaggedArcPtr) -> TaggedArcPtr {
    // SAFETY: We're not dropping the `Arc`.
    let Some(arc) = (unsafe { repr.try_as_arc_owned() }) else {
        return repr;
    };
    // increase the ref count
    mem::forget(Arc::clone(&arc));
    repr
}

impl fmt::Display for Symbol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.as_str().fmt(f)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn smoke_test() {
        Symbol::intern("isize");
        let base_len = MAP.get().unwrap().len();
        let hello = Symbol::intern("hello");
        let world = Symbol::intern("world");
        let more_worlds = world.clone();
        let bang = Symbol::intern("!");
        let q = Symbol::intern("?");
        assert_eq!(MAP.get().unwrap().len(), base_len + 4);
        let bang2 = Symbol::intern("!");
        assert_eq!(MAP.get().unwrap().len(), base_len + 4);
        drop(bang2);
        assert_eq!(MAP.get().unwrap().len(), base_len + 4);
        drop(q);
        assert_eq!(MAP.get().unwrap().len(), base_len + 3);
        let default = Symbol::intern("default");
        let many_worlds = world.clone();
        assert_eq!(MAP.get().unwrap().len(), base_len + 3);
        assert_eq!(
            "hello default world!",
            format!("{} {} {}{}", hello.as_str(), default.as_str(), world.as_str(), bang.as_str())
        );
        drop(default);
        assert_eq!(
            "hello world!",
            format!("{} {}{}", hello.as_str(), world.as_str(), bang.as_str())
        );
        drop(many_worlds);
        drop(more_worlds);
        drop(hello);
        drop(world);
        drop(bang);
        assert_eq!(MAP.get().unwrap().len(), base_len);
    }
}