Skip to main content

any_intern/
any.rs

1use super::common::{self, Interned, RawInterned, UnsafeLock};
2use bumpalo::Bump;
3use hashbrown::{HashTable, hash_table::Entry};
4use std::{
5    alloc::Layout,
6    any::TypeId,
7    borrow,
8    cell::Cell,
9    hash::{BuildHasher, Hash},
10    mem, ptr,
11};
12
13/// A type-erased interner for storing and deduplicating values of a single type.
14///
15/// This interner is simply a wrapper of [`AnyInternSet`] with interior mutability. If you need a
16/// collection of interners for various types like a hash map of interners, then consider using the
17/// `AnyInterSet` with a container providing interior mutability such as [`ManualMutex`].
18///
19/// # Examples
20///
21/// ```
22/// use any_intern::AnyInterner;
23///
24/// #[derive(PartialEq, Eq, Hash, Debug)]
25/// struct A(u32);
26///
27/// let interner = AnyInterner::of::<A>();
28///
29/// unsafe {
30///     let a1 = interner.intern(A(42));
31///     let a2 = interner.intern(A(42));
32///     assert_eq!(a1, a2); // Same value, same reference
33///
34///     let a3 = interner.intern(A(99));
35///     assert_ne!(a1, a3); // Different values, different references
36/// }
37/// ```
38///
39/// # Safety
40///
41/// Many methods in `AnyInterner` are marked as `unsafe` because they rely on the caller to ensure
42/// that the correct type is used when interacting with the interner. Using an incorrect type can
43/// lead to undefined behavior.
44pub struct AnyInterner<S = fxhash::FxBuildHasher> {
45    inner: UnsafeLock<AnyInternSet<S>>,
46}
47
48impl AnyInterner {
49    pub fn of<K: 'static>() -> Self {
50        // Safety: Only one instance
51        let inner = unsafe { UnsafeLock::new(AnyInternSet::of::<K>()) };
52        Self { inner }
53    }
54}
55
56impl<S: BuildHasher> AnyInterner<S> {
57    pub fn with_hasher<K: 'static>(hash_builder: S) -> Self {
58        // Safety: Only one instance
59        let inner = unsafe { UnsafeLock::new(AnyInternSet::with_hasher::<K>(hash_builder)) };
60        Self { inner }
61    }
62
63    /// Returns number of values the interner contains.
64    pub fn len(&self) -> usize {
65        self.with_inner(|set| set.len())
66    }
67
68    /// Returns true if the interner is empty.
69    pub fn is_empty(&self) -> bool {
70        self.with_inner(|set| set.is_empty())
71    }
72
73    /// Stores a value in the interner, returning a reference to the interned value.
74    ///
75    /// This method inserts the given value into the interner if it does not already exist. If the
76    /// value already exists, a reference to the existing value is returned.
77    ///
78    /// # Examples
79    ///
80    /// ```
81    /// use any_intern::AnyInterner;
82    ///
83    /// #[derive(PartialEq, Eq, Hash, Debug)]
84    /// struct A(u32);
85    ///
86    /// let interner = AnyInterner::of::<A>();
87    ///
88    /// unsafe {
89    ///     let a1 = interner.intern(A(42));
90    ///     let a2 = interner.intern(A(42));
91    ///     assert_eq!(a1, a2); // Same value, same reference
92    ///     assert_eq!(a1.raw().as_ptr(), a2.raw().as_ptr());
93    /// }
94    /// ```
95    ///
96    /// # Safety
97    ///
98    /// Undefined behavior if incorrect type `K` is given.
99    pub unsafe fn intern<K>(&self, value: K) -> Interned<'_, K>
100    where
101        K: Hash + Eq + 'static,
102    {
103        self.with_inner(|set| unsafe { set.intern(value) })
104    }
105
106    /// Stores a value in the interner, creating it only if it does not already exist.
107    ///
108    /// This method allows you to provide a key and a closure to generate the value. If the key
109    /// already exists in the interner, the closure is not called, and a reference to the existing
110    /// value is returned. If the key does not exist, the closure is called to create the value,
111    /// which is then stored in the interner.
112    ///
113    /// This method is useful when the value is expensive to compute, as it avoids unnecessary
114    /// computation if the value already exists.
115    ///
116    /// # Examples
117    ///
118    /// ```
119    /// use any_intern::AnyInterner;
120    ///
121    ///
122    /// #[derive(PartialEq, Eq, Hash, Debug)]
123    /// struct A(i32);
124    ///
125    /// impl std::borrow::Borrow<i32> for A {
126    ///     fn borrow(&self) -> &i32 {
127    ///         &self.0
128    ///     }
129    /// }
130    ///
131    /// let interner = AnyInterner::of::<A>();
132    ///
133    /// unsafe {
134    ///     let a = interner.intern_with(&42, || A(42));
135    ///     assert_eq!(interner.len(), 1);
136    ///     assert_eq!(*a, A(42));
137    ///
138    ///     let b = interner.intern_with(&42, || A(99)); // Closure is not called
139    ///     assert_eq!(interner.len(), 1);
140    ///     assert_eq!(*b, A(42));
141    ///
142    ///     let c = interner.intern_with(&43, || A(43));
143    ///     assert_eq!(interner.len(), 2);
144    ///     assert_eq!(*c, A(43));
145    /// }
146    /// ```
147    ///
148    /// # Safety
149    ///
150    /// Undefined behavior if incorrect type `K` is given.
151    pub unsafe fn intern_with<K, Q, F>(&self, key: &Q, make_value: F) -> Interned<'_, K>
152    where
153        K: borrow::Borrow<Q> + 'static,
154        Q: Hash + Eq + ?Sized,
155        F: FnOnce() -> K,
156    {
157        self.with_inner(|set| unsafe { set.intern_with(key, make_value) })
158    }
159
160    /// Retrieves a reference to a value in the interner based on the provided key.
161    ///
162    /// This method checks if a value corresponding to the given key exists in the interner. If it
163    /// exists, a reference to the interned value is returned. Otherwise, `None` is returned.
164    ///
165    /// # Eaxmples
166    ///
167    /// ```
168    /// use any_intern::AnyInterner;
169    ///
170    /// let interner = AnyInterner::of::<i32>();
171    /// unsafe {
172    ///     interner.intern(42);
173    ///     assert_eq!(interner.get::<i32, _>(&42).as_deref(), Some(&42));
174    ///     assert!(interner.get::<i32, _>(&99).is_none());
175    /// }
176    /// ```
177    ///
178    /// # Safety
179    ///
180    /// Undefined behavior if incorrect type `K` is given.
181    pub unsafe fn get<K, Q>(&self, key: &Q) -> Option<Interned<'_, K>>
182    where
183        K: borrow::Borrow<Q> + 'static,
184        Q: Hash + Eq + ?Sized,
185    {
186        self.with_inner(|set| unsafe { set.get(key) })
187    }
188
189    /// Returns true if the interner contains values of the given type.
190    pub fn is_type_of<K: 'static>(&self) -> bool {
191        self.with_inner(|set| set.is_type_of::<K>())
192    }
193
194    /// Removes all items in the interner.
195    ///
196    /// Although the interner support interior mutability, clear method requires mutable access
197    /// to the interner to invalidate all [`Interned`]s referencing the interner.
198    pub fn clear(&mut self) {
199        self.with_inner(|set| set.clear())
200    }
201
202    fn with_inner<'this, F, R>(&'this self, f: F) -> R
203    where
204        F: FnOnce(&'this mut AnyInternSet<S>) -> R,
205        R: 'this,
206    {
207        // Safety: Mutex unlocking is paired with the locking.
208        unsafe {
209            let set = self.inner.lock().as_mut();
210            let ret = f(set);
211            self.inner.unlock();
212            ret
213        }
214    }
215}
216
217/// A type-erased interning set for storing and deduplicating values of a single type without
218/// interior mutability.
219///
220/// # Examples
221///
222/// ```
223/// use any_intern::AnyInternSet;
224///
225/// #[derive(PartialEq, Eq, Hash, Debug)]
226/// struct A(u32);
227///
228/// let mut set = AnyInternSet::of::<A>();
229///
230/// unsafe {
231///     let a1 = set.intern(A(42)).raw();
232///     let a2 = set.intern(A(42)).raw();
233///     assert_eq!(a1, a2); // Same value, same reference
234///
235///     let a3 = set.intern(A(99)).raw();
236///     assert_ne!(a1, a3); // Different values, different references
237/// }
238/// ```
239///
240/// # Safety
241///
242/// Many methods in `AnyInternSet` are marked as `unsafe` because they rely on the caller to ensure
243/// that the correct type is used when interacting with the interner. Using an incorrect type can
244/// lead to undefined behavior.
245pub struct AnyInternSet<S = fxhash::FxBuildHasher> {
246    arena: AnyArena,
247    set: HashTable<RawInterned>,
248    hash_builder: S,
249}
250
251impl AnyInternSet {
252    pub fn of<K: 'static>() -> Self {
253        Self {
254            arena: AnyArena::of::<K>(),
255            set: HashTable::new(),
256            hash_builder: Default::default(),
257        }
258    }
259}
260
261impl<S: Default> AnyInternSet<S> {
262    pub fn default_of<K: 'static>() -> Self {
263        Self {
264            arena: AnyArena::of::<K>(),
265            set: HashTable::new(),
266            hash_builder: Default::default(),
267        }
268    }
269}
270
271impl<S: BuildHasher> AnyInternSet<S> {
272    pub fn with_hasher<K: 'static>(hash_builder: S) -> Self {
273        Self {
274            arena: AnyArena::of::<K>(),
275            set: HashTable::new(),
276            hash_builder,
277        }
278    }
279
280    /// Returns number of values in the set.
281    pub const fn len(&self) -> usize {
282        self.arena.len()
283    }
284
285    /// Returns true if the set is empty.
286    pub const fn is_empty(&self) -> bool {
287        self.len() == 0
288    }
289
290    /// Stores a value in the set, returning a reference to the value.
291    ///
292    /// This method inserts the given value into the set if it does not already exist. If the value
293    /// already exists, a reference to the existing value is returned.
294    ///
295    /// # Examples
296    ///
297    /// ```
298    /// use any_intern::AnyInternSet;
299    ///
300    /// #[derive(PartialEq, Eq, Hash, Debug)]
301    /// struct A(u32);
302    ///
303    /// let mut set = AnyInternSet::of::<A>();
304    ///
305    /// unsafe {
306    ///     let a1 = set.intern(A(42)).raw();
307    ///     let a2 = set.intern(A(42)).raw();
308    ///     assert_eq!(a1, a2); // Same value, same reference
309    /// }
310    /// ```
311    ///
312    /// # Safety
313    ///
314    /// Undefined behavior if incorrect type `K` is given.
315    pub unsafe fn intern<K>(&mut self, value: K) -> Interned<'_, K>
316    where
317        K: Hash + Eq + 'static,
318    {
319        debug_assert!(self.is_type_of::<K>());
320
321        unsafe {
322            let hash = Self::hash(&self.hash_builder, &value);
323            let eq = Self::table_eq::<K, K>(&value);
324            let hasher = Self::table_hasher::<K, K>(&self.hash_builder);
325            match self.set.entry(hash, eq, hasher) {
326                Entry::Occupied(entry) => Interned::from_erased_raw(*entry.get()),
327                Entry::Vacant(entry) => {
328                    let ref_ = self.arena.alloc(value);
329                    let interned = Interned::unique(ref_);
330                    let raw = interned.erased_raw();
331                    entry.insert(raw);
332                    interned
333                }
334            }
335        }
336    }
337
338    /// Stores a value in the set, creating it only if it does not already exist.
339    ///
340    /// This method allows you to provide a key and a closure to generate the value. If the key
341    /// already exists in the set, the closure is not called, and a reference to the existing value
342    /// is returned. If the key does not exist, the closure is called to create the value, which is
343    /// then stored in the set.
344    ///
345    /// This method is useful when the value is expensive to compute, as it avoids unnecessary
346    /// computation if the value already exists.
347    ///
348    /// # Examples
349    ///
350    /// ```
351    /// use any_intern::AnyInternSet;
352    ///
353    ///
354    /// #[derive(PartialEq, Eq, Hash, Debug)]
355    /// struct A(i32);
356    ///
357    /// impl std::borrow::Borrow<i32> for A {
358    ///     fn borrow(&self) -> &i32 {
359    ///         &self.0
360    ///     }
361    /// }
362    ///
363    /// let mut set = AnyInternSet::of::<A>();
364    ///
365    /// unsafe {
366    ///     let a = set.intern_with(&42, || A(42));
367    ///     assert_eq!(*a, A(42));
368    ///     assert_eq!(set.len(), 1);
369    ///
370    ///     let b = set.intern_with(&42, || A(99)); // Closure is not called
371    ///     assert_eq!(*b, A(42));
372    ///     assert_eq!(set.len(), 1);
373    ///
374    ///     let c = set.intern_with(&43, || A(43));
375    ///     assert_eq!(*c, A(43));
376    ///     assert_eq!(set.len(), 2);
377    /// }
378    /// ```
379    ///
380    /// # Safety
381    ///
382    /// Undefined behavior if incorrect type `K` is given.
383    pub unsafe fn intern_with<K, Q, F>(&mut self, key: &Q, make_value: F) -> Interned<'_, K>
384    where
385        K: borrow::Borrow<Q> + 'static,
386        Q: Hash + Eq + ?Sized,
387        F: FnOnce() -> K,
388    {
389        debug_assert!(self.is_type_of::<K>());
390
391        unsafe {
392            let hash = Self::hash(&self.hash_builder, key);
393            let eq = Self::table_eq::<K, Q>(key);
394            let hasher = Self::table_hasher::<K, Q>(&self.hash_builder);
395            match self.set.entry(hash, eq, hasher) {
396                Entry::Occupied(entry) => Interned::from_erased_raw(*entry.get()),
397                Entry::Vacant(entry) => {
398                    let value = make_value();
399                    let ref_ = self.arena.alloc(value);
400                    let interned = Interned::unique(ref_);
401                    let raw = interned.erased_raw();
402                    entry.insert(raw);
403                    interned
404                }
405            }
406        }
407    }
408
409    /// Retrieves a reference to a value in the set based on the provided key.
410    ///
411    /// This method checks if a value corresponding to the given key exists in the set. If it
412    /// exists, a reference to the value is returned. Otherwise, `None` is returned.
413    ///
414    /// # Eaxmples
415    ///
416    /// ```
417    /// use any_intern::AnyInternSet;
418    ///
419    /// let mut set = AnyInternSet::of::<i32>();
420    /// unsafe {
421    ///     set.intern(42);
422    ///     assert_eq!(set.get::<i32, _>(&42).as_deref(), Some(&42));
423    ///     assert!(set.get::<i32, _>(&99).is_none());
424    /// }
425    /// ```
426    ///
427    /// # Safety
428    ///
429    /// Undefined behavior if incorrect type `K` is given.
430    pub unsafe fn get<K, Q>(&self, key: &Q) -> Option<Interned<'_, K>>
431    where
432        K: borrow::Borrow<Q> + 'static,
433        Q: Hash + Eq + ?Sized,
434    {
435        debug_assert!(self.is_type_of::<K>());
436
437        unsafe {
438            let hash = Self::hash(&self.hash_builder, key);
439            let eq = Self::table_eq::<K, Q>(key);
440            self.set
441                .find(hash, eq)
442                .map(|raw| Interned::from_erased_raw(*raw))
443        }
444    }
445
446    /// Returns true if the set contains values of the given type.
447    pub fn is_type_of<K: 'static>(&self) -> bool {
448        self.arena.is_type_of::<K>()
449    }
450
451    /// Removes all items in the set.
452    pub fn clear(&mut self) {
453        self.arena.clear();
454        self.set.clear();
455    }
456
457    /// Returns `eq` closure that is used for some methods on the [`HashTable`].
458    ///
459    /// # Safety
460    ///
461    /// Undefined behavior if incorrect type `K` is given.
462    unsafe fn table_eq<K, Q>(key: &Q) -> impl FnMut(&RawInterned) -> bool
463    where
464        K: borrow::Borrow<Q>,
465        Q: Hash + Eq + ?Sized,
466    {
467        move |entry: &RawInterned| unsafe {
468            let value = entry.cast::<K>().as_ref();
469            value.borrow() == key
470        }
471    }
472
473    /// Returns `hasher` closure that is used for some methods on the [`HashTable`].
474    ///
475    /// # Safety
476    ///
477    /// Undefined behavior if incorrect type `K` is given.
478    unsafe fn table_hasher<K, Q>(hash_builder: &S) -> impl Fn(&RawInterned) -> u64
479    where
480        K: borrow::Borrow<Q>,
481        Q: Hash + Eq + ?Sized,
482    {
483        |entry: &RawInterned| unsafe {
484            let value = entry.cast::<K>().as_ref();
485            Self::hash(hash_builder, value.borrow())
486        }
487    }
488
489    fn hash<K: Hash + ?Sized>(hash_builder: &S, value: &K) -> u64 {
490        hash_builder.hash_one(value)
491    }
492}
493
494pub struct AnyArena {
495    bump: Bump,
496    ty: TypeId,
497    stride: usize,
498    raw_drop_slice: Option<unsafe fn(*mut u8, usize)>,
499    len: Cell<usize>,
500}
501
502impl AnyArena {
503    pub fn of<T: 'static>() -> Self {
504        Self {
505            bump: Bump::new(),
506            ty: TypeId::of::<T>(),
507            stride: Layout::new::<T>().pad_to_align().size(),
508            raw_drop_slice: if mem::needs_drop::<T>() {
509                Some(common::cast_then_drop_slice::<T>)
510            } else {
511                None
512            },
513            len: Cell::new(0),
514        }
515    }
516
517    pub fn is_type_of<T: 'static>(&self) -> bool {
518        TypeId::of::<T>() == self.ty
519    }
520
521    /// Returns number of elements in this arena.
522    pub const fn len(&self) -> usize {
523        self.len.get()
524    }
525
526    pub const fn is_empty(&self) -> bool {
527        self.len() == 0
528    }
529
530    pub fn alloc<T: 'static>(&self, value: T) -> &mut T {
531        debug_assert!(self.is_type_of::<T>());
532
533        self.len.set(self.len() + 1);
534        self.bump.alloc(value)
535    }
536
537    pub fn clear(&mut self) {
538        self.drop_all();
539        self.bump.reset();
540        self.len.set(0);
541    }
542
543    fn drop_all(&mut self) {
544        if let Some(raw_drop_slice) = self.raw_drop_slice {
545            if self.stride > 0 {
546                for chunk in self.bump.iter_allocated_chunks() {
547                    // Chunk would not be divisible by the `stride` especially when the stride is
548                    // greater than 16. In that case, we should ignore the remainder.
549                    let num_elems = chunk.len() / self.stride;
550                    let ptr = chunk.as_ptr().cast::<u8>().cast_mut();
551                    unsafe {
552                        raw_drop_slice(ptr, num_elems);
553                    }
554                }
555            } else {
556                let ptr = ptr::dangling_mut::<u8>();
557                unsafe {
558                    raw_drop_slice(ptr, self.len());
559                }
560            }
561        }
562    }
563}
564
565impl Drop for AnyArena {
566    fn drop(&mut self) {
567        self.drop_all();
568    }
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_any_interner() {
        #[derive(PartialEq, Eq, Hash, Debug)]
        struct A(i32);

        let interner = AnyInterner::of::<A>();

        // Safety: the interner was created for `A`, and only `A` values are interned below.
        let (first_zero, second_zero, one) = unsafe {
            (
                interner.intern(A(0)),
                interner.intern(A(0)),
                interner.intern(A(1)),
            )
        };

        // Equal values intern to the same entry; different values do not.
        assert_eq!(first_zero, second_zero);
        assert_ne!(first_zero, one);
    }

    #[test]
    fn test_arena() {
        test_arena_alloc();
        test_arena_drop();
    }

    /// Allocates `FIRST..=LAST` into a `u32` arena and checks that every returned reference still
    /// reads its value, by comparing the total against the arithmetic-series formula.
    fn test_arena_alloc() {
        const FIRST: u32 = 0;
        const LAST: u32 = 100;
        const EXPECTED: u32 = (LAST + FIRST) * (LAST - FIRST + 1) / 2;

        let arena = AnyArena::of::<u32>();
        let slots: Vec<&mut u32> = (FIRST..=LAST).map(|n| arena.alloc(n)).collect();
        let total: u32 = slots.iter().map(|slot| **slot).sum();
        assert_eq!(total, EXPECTED);
    }

    /// Checks that dropping an arena runs each stored value's destructor exactly once, for
    /// several element sizes and alignments and for a zero-sized type.
    fn test_arena_drop() {
        macro_rules! check {
            ($len:literal, $align:literal) => {{
                thread_local! {
                    static DROPPED_SUM: Cell<u32> = const { Cell::new(0) };
                    static DROP_COUNT: Cell<u32> = const { Cell::new(0) };
                }

                #[repr(align($align))]
                struct Elem([u8; $len]);

                // `Elem::new` stores each index as a `u8`, so the length must stay below 256.
                const _: () = assert!($len < 256);

                impl Elem {
                    fn new() -> Self {
                        let mut bytes = [0u8; $len];
                        for (index, byte) in bytes.iter_mut().enumerate() {
                            *byte = index as u8;
                        }
                        Self(bytes)
                    }

                    /// Byte sum of one `Elem`: `0 + 1 + ... + (len - 1)`.
                    fn expected_sum() -> u32 {
                        $len * ($len - 1) / 2
                    }
                }

                impl Drop for Elem {
                    fn drop(&mut self) {
                        let sum: u32 = self.0.iter().copied().map(u32::from).sum();
                        DROPPED_SUM.set(DROPPED_SUM.get() + sum);
                        DROP_COUNT.set(DROP_COUNT.get() + 1);
                    }
                }

                struct Zst;

                impl Drop for Zst {
                    fn drop(&mut self) {
                        DROP_COUNT.set(DROP_COUNT.get() + 1);
                    }
                }

                const REPEAT: u32 = 10;

                // Sized element type: destructors must observe the original bytes.
                let sized = AnyArena::of::<Elem>();
                for _ in 0..REPEAT {
                    sized.alloc(Elem::new());
                }
                drop(sized);

                assert_eq!(DROPPED_SUM.get(), Elem::expected_sum() * REPEAT);
                assert_eq!(DROP_COUNT.get(), REPEAT);
                DROPPED_SUM.set(0);
                DROP_COUNT.set(0);

                // Zero-sized element type: only the number of destructor calls is observable.
                let zst = AnyArena::of::<Zst>();
                for _ in 0..REPEAT {
                    zst.alloc(Zst);
                }
                drop(zst);

                assert_eq!(DROP_COUNT.get(), REPEAT);
                DROP_COUNT.set(0);
            }};
        }

        // Runs `check!` for one array length across every tested alignment.
        macro_rules! check_aligns {
            ($len:literal) => {
                check!($len, 1);
                check!($len, 2);
                check!($len, 4);
                check!($len, 8);
                check!($len, 16);
                check!($len, 32);
                check!($len, 64);
                check!($len, 128);
                check!($len, 256);
            };
        }

        check_aligns!(1);
        check_aligns!(100);
    }
}